package grib

import (
	"context"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"sync"
	"time"

	"github.com/nilsmagnus/grib/griblib"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"

	"predictor-refactored/internal/datasets"
	wgfs "predictor-refactored/internal/weather/gfs"
)

// URLFunc returns the GRIB URL for one (date, runHour, member, step, levelSet).
// date is the run date formatted as YYYYMMDD and runHour is the run's UTC hour.
// Sources that don't have members (GFS) ignore the member argument.
type URLFunc func(date string, runHour, member, step int, ls wgfs.LevelSet) string

// Downloader is the generic GRIB-cube downloader.
//
// A Source plugs in its variant, URL templating, and member-resolution
// logic; the Downloader runs the parallel idx fetch, byte-range download,
// GRIB decode, and blit loop with manifest-based resume.
//
// Parallel, Client and Log may be left at their zero values; see the
// parallel, client and log accessors for the fallbacks used.
type Downloader struct {
	// Variant describes the cube being filled (forecast hours, pressure
	// levels, variables, on-disk size).
	Variant *wgfs.Variant
	// URLs builds the GRIB URL for each (step, level set) unit.
	URLs URLFunc
	// Parallel bounds the number of concurrent unit downloads.
	Parallel int
	// Client performs all HTTP requests.
	Client *http.Client
	// Log receives completion messages.
	Log *zap.Logger
}

// Shared fallbacks for Downloaders that leave Client or Log unset, so the
// accessors below do not construct a fresh value on every call.
var (
	defaultDownloadClient = &http.Client{Timeout: 2 * time.Minute}
	nopDownloadLogger     = zap.NewNop()
)

// log returns d.Log, or a shared no-op logger when it is nil.
func (d *Downloader) log() *zap.Logger {
	if d.Log == nil {
		return nopDownloadLogger
	}
	return d.Log
}

// client returns d.Client, or a shared client with a 2-minute timeout when
// it is nil.
func (d *Downloader) client() *http.Client {
	if d.Client == nil {
		return defaultDownloadClient
	}
	return d.Client
}

// parallel returns d.Parallel, defaulting to 8 when it is zero or negative.
func (d *Downloader) parallel() int {
	if d.Parallel <= 0 {
		return 8
	}
	return d.Parallel
}

// neededVariables is the GRIB variable set every source extracts.
var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}

// Run downloads the dataset identified by id into store for ensemble member
// member; non-ensemble sources may pass member=0. Units already recorded in
// the write handle's manifest are skipped, so an interrupted run can be
// resumed. prog and throttle may be nil.
func (d *Downloader) Run(ctx context.Context, id datasets.DatasetID, member int, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error { if prog == nil { prog = noopSink{} } handle, err := store.BeginWrite(id) if err != nil { return fmt.Errorf("begin write: %w", err) } manifest := handle.Manifest() file, err := openOrCreateCube(handle.Path(), d.Variant) if err != nil { _ = handle.Abort() return err } epoch := id.Epoch.UTC() date := epoch.Format("20060102") runHour := epoch.Hour() steps := d.Variant.Hours() if hr := id.Subset.HourRange; hr != nil { filtered := steps[:0] for _, step := range steps { if step >= hr.MinHour && step <= hr.MaxHour { filtered = append(filtered, step) } } steps = filtered } prog.SetTotal(len(steps) * 2) for range manifest.Units() { prog.StepComplete() } start := time.Now() g, ctx := errgroup.WithContext(ctx) g.SetLimit(d.parallel()) var fileMu sync.Mutex for _, step := range steps { hourIdx := d.Variant.HourIndex(step) if hourIdx < 0 { continue } for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} { unit := unitKey(step, ls) if manifest.Has(unit) { continue } g.Go(func() error { url := d.URLs(date, runHour, member, step, ls) if err := d.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil { return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err) } if err := manifest.Mark(unit); err != nil { return fmt.Errorf("mark unit: %w", err) } prog.StepComplete() return nil }) } } if err := g.Wait(); err != nil { _ = file.Close() if errors.Is(err, context.Canceled) { return err } if len(manifest.Units()) == 0 { _ = handle.Abort() } return err } if err := file.Flush(); err != nil { _ = file.Close() return fmt.Errorf("flush: %w", err) } if err := file.Close(); err != nil { return fmt.Errorf("close: %w", err) } if err := handle.Commit(); err != nil { return fmt.Errorf("commit: %w", err) } d.log().Info("download complete", zap.String("variant", d.Variant.ID), 
zap.Time("epoch", epoch), zap.Duration("elapsed", time.Since(start))) return nil } // openOrCreateCube opens an existing cube at path if it matches variant's // expected size, else truncate-creates a new one. func openOrCreateCube(path string, variant *wgfs.Variant) (*wgfs.File, error) { info, err := os.Stat(path) if err == nil && info.Size() == variant.DatasetSize() { return wgfs.OpenWritable(path, variant) } if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("stat cube: %w", err) } return wgfs.Create(path, variant) } // downloadAndBlit fetches and decodes one (URL, level-set) chunk. func (d *Downloader) downloadAndBlit( ctx context.Context, file *wgfs.File, fileMu *sync.Mutex, baseURL string, hourIdx int, ls wgfs.LevelSet, prog datasets.ProgressSink, throttle datasets.Throttle, ) error { idxBody, err := d.httpGet(ctx, baseURL+".idx", throttle, prog) if err != nil { return fmt.Errorf("idx: %w", err) } entries := ParseIdx(idxBody) filtered := FilterIdx(entries, neededVariables) var relevant []IdxEntry for _, e := range filtered { set, ok := d.Variant.PressureLevelSet(e.LevelMB) if ok && set == ls { relevant = append(relevant, e) } } if len(relevant) == 0 { return nil } ranges := EntriesToRanges(relevant) tmp, err := os.CreateTemp("", "grib-msg-*.tmp") if err != nil { return fmt.Errorf("temp: %w", err) } tmpPath := tmp.Name() defer os.Remove(tmpPath) for _, r := range ranges { body, err := d.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog) if err != nil { tmp.Close() return fmt.Errorf("range: %w", err) } if _, err := tmp.Write(body); err != nil { tmp.Close() return fmt.Errorf("write tmp: %w", err) } } if err := tmp.Close(); err != nil { return err } f, err := os.Open(tmpPath) if err != nil { return err } messages, err := griblib.ReadMessages(f) f.Close() if err != nil { return fmt.Errorf("read grib: %w", err) } for _, msg := range messages { if msg.Section4.ProductDefinitionTemplateNumber != 0 { continue } p := 
msg.Section4.ProductDefinitionTemplate varIdx := d.Variant.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber)) if varIdx < 0 { continue } if p.FirstSurface.Type != 100 { continue } pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0)) levelIdx := d.Variant.PressureIndex(pressureMB) if levelIdx < 0 { continue } data := msg.Data() fileMu.Lock() err := file.BlitGribData(hourIdx, levelIdx, varIdx, data) fileMu.Unlock() if err != nil { return fmt.Errorf("blit: %w", err) } } return nil } func (d *Downloader) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { var lastErr error for attempt := range 3 { if attempt > 0 { select { case <-time.After(time.Duration(attempt*2) * time.Second): case <-ctx.Done(): return nil, ctx.Err() } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } resp, err := d.client().Do(req) if err != nil { lastErr = err continue } body, err := readThrottled(ctx, resp.Body, throttle, prog) resp.Body.Close() if resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) continue } if err != nil { lastErr = err continue } return body, nil } return nil, fmt.Errorf("after 3 attempts: %w", lastErr) } func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { var lastErr error for attempt := range 3 { if attempt > 0 { select { case <-time.After(time.Duration(attempt*2) * time.Second): case <-ctx.Done(): return nil, ctx.Err() } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) resp, err := d.client().Do(req) if err != nil { lastErr = err continue } body, err := readThrottled(ctx, resp.Body, throttle, prog) resp.Body.Close() if resp.StatusCode != 
http.StatusPartialContent && resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("HTTP %d for range", resp.StatusCode) continue } if err != nil { lastErr = err continue } return body, nil } return nil, fmt.Errorf("after 3 attempts: %w", lastErr) } func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { buf := make([]byte, 0, 64*1024) chunk := make([]byte, 32*1024) for { if throttle != nil { if err := throttle.Wait(ctx, len(chunk)); err != nil { return nil, err } } n, err := r.Read(chunk) if n > 0 { buf = append(buf, chunk[:n]...) prog.Bytes(int64(n)) } if errors.Is(err, io.EOF) { return buf, nil } if err != nil { return nil, err } } } func unitKey(step int, ls wgfs.LevelSet) string { return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls)) } func levelSetLabel(ls wgfs.LevelSet) string { if ls == wgfs.LevelSetB { return "B" } return "A" } type noopSink struct{} func (noopSink) SetTotal(int) {} func (noopSink) StepComplete() {} func (noopSink) Bytes(int64) {}