// Package downloader fetches GFS forecast data from S3 and assembles it
// into dataset files.
package downloader

import (
	"context"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"sync/atomic"
	"time"

	"predictor-refactored/internal/dataset"

	"github.com/nilsmagnus/grib/griblib"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)

// Downloader handles fetching GFS forecast data from S3 and assembling dataset files.
type Downloader struct {
	cfg    *Config
	client *http.Client
	log    *zap.Logger
}

// NewDownloader creates a new Downloader.
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
	return &Downloader{
		cfg: cfg,
		// One shared client so the transport can reuse connections; the
		// generous timeout covers large GRIB range downloads.
		client: &http.Client{
			Timeout: 2 * time.Minute,
		},
		log: log,
	}
}

// neededVariables is the set of GRIB variable names we need.
var neededVariables = map[string]bool{
	"HGT":  true,
	"UGRD": true,
	"VGRD": true,
}

// FindLatestRun finds the most recent available GFS model run on S3.
// It checks the last forecast step of each run to confirm availability,
// walking backward in 6-hour increments over the last 8 runs.
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
	now := time.Now().UTC()
	// GFS cycles are on 6-hour boundaries (00/06/12/18 UTC); start from the
	// most recent one.
	hour := now.Hour() - (now.Hour() % 6)
	current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
	for i := 0; i < 8; i++ {
		// Surface cancellation directly instead of burning through the
		// remaining probes and returning the misleading "not found" error.
		if err := ctx.Err(); err != nil {
			return time.Time{}, err
		}
		date := current.Format("20060102")
		// Probe the .idx of the final forecast step: if it is present, the
		// run is treated as complete.
		url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx"
		req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
		if err != nil {
			current = current.Add(-6 * time.Hour)
			continue
		}
		resp, err := d.client.Do(req)
		if err == nil {
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				d.log.Info("found latest model run",
					zap.Time("run", current),
					zap.String("verified_url", url))
				return current, nil
			}
		}
		current = current.Add(-6 * time.Hour)
	}
	return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}

// progress tracks download progress across concurrent goroutines.
type progress struct { bytesDownloaded atomic.Int64 stepsCompleted atomic.Int64 totalSteps int64 startTime time.Time log *zap.Logger } func newProgress(totalSteps int, log *zap.Logger) *progress { return &progress{ totalSteps: int64(totalSteps), startTime: time.Now(), log: log, } } func (p *progress) addBytes(n int64) { p.bytesDownloaded.Add(n) } func (p *progress) completeStep() { done := p.stepsCompleted.Add(1) total := p.totalSteps bytes := p.bytesDownloaded.Load() elapsed := time.Since(p.startTime).Seconds() pct := float64(done) / float64(total) * 100 mbDownloaded := float64(bytes) / (1024 * 1024) mbPerSec := 0.0 if elapsed > 0 { mbPerSec = mbDownloaded / elapsed } // Estimate remaining eta := "" if done > 0 && done < total { secsPerStep := elapsed / float64(done) remaining := secsPerStep * float64(total-done) if remaining > 60 { eta = fmt.Sprintf("%.0fm%02.0fs", math.Floor(remaining/60), math.Mod(remaining, 60)) } else { eta = fmt.Sprintf("%.0fs", remaining) } } p.log.Info("download progress", zap.String("progress", fmt.Sprintf("%d/%d", done, total)), zap.String("percent", fmt.Sprintf("%.1f%%", pct)), zap.String("downloaded", fmt.Sprintf("%.1f MB", mbDownloaded)), zap.String("speed", fmt.Sprintf("%.1f MB/s", mbPerSec)), zap.String("eta", eta)) } // Download downloads a complete forecast and assembles a dataset file. // Returns the path to the completed dataset file. 
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) { date := run.Format("20060102") runHour := run.Hour() finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215")) tempPath := finalPath + ".downloading" // Check if final dataset already exists if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize { d.log.Info("dataset already exists", zap.String("path", finalPath)) return finalPath, nil } steps := dataset.Hours() totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step prog := newProgress(totalSteps, d.log) d.log.Info("starting dataset download", zap.Time("run", run), zap.Int("total_steps", totalSteps), zap.String("temp_path", tempPath)) // Create the dataset file ds, err := dataset.Create(tempPath) if err != nil { return "", fmt.Errorf("create dataset: %w", err) } defer ds.Close() // Process each forecast step with bounded concurrency g, ctx := errgroup.WithContext(ctx) sem := make(chan struct{}, d.cfg.Parallel) for _, step := range steps { step := step hourIdx := dataset.HourIndex(step) if hourIdx < 0 { continue } // Download pgrb2 (level set A) sem <- struct{}{} g.Go(func() error { defer func() { <-sem }() url := dataset.GribURL(date, runHour, step) err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA, prog) if err != nil { return fmt.Errorf("step %d pgrb2: %w", step, err) } prog.completeStep() return nil }) // Download pgrb2b (level set B) sem <- struct{}{} g.Go(func() error { defer func() { <-sem }() url := dataset.GribURLB(date, runHour, step) err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB, prog) if err != nil { return fmt.Errorf("step %d pgrb2b: %w", step, err) } prog.completeStep() return nil }) } if err := g.Wait(); err != nil { os.Remove(tempPath) return "", err } elapsed := time.Since(prog.startTime) totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024) d.log.Info("download complete, flushing to disk", zap.String("downloaded", 
fmt.Sprintf("%.1f MB", totalMB)), zap.Duration("elapsed", elapsed), zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds()))) // Flush to disk if err := ds.Flush(); err != nil { os.Remove(tempPath) return "", fmt.Errorf("flush dataset: %w", err) } // Close before rename ds.Close() // Atomic rename if err := os.Rename(tempPath, finalPath); err != nil { os.Remove(tempPath) return "", fmt.Errorf("rename dataset: %w", err) } d.log.Info("dataset ready", zap.String("path", finalPath)) return finalPath, nil } // DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset. func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error { return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil) } // downloadAndBlit is the internal implementation with optional progress tracking. func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error { // 1. Download .idx idxURL := baseURL + ".idx" idxBody, err := d.httpGet(ctx, idxURL) if err != nil { return fmt.Errorf("download idx: %w", err) } // 2. Parse and filter entries := ParseIdx(idxBody) filtered := FilterIdx(entries, neededVariables) // Further filter to only levels in this level set var relevant []IdxEntry for _, e := range filtered { ls, ok := dataset.PressureLevelSet(e.LevelMB) if ok && ls == levelSet { relevant = append(relevant, e) } } if len(relevant) == 0 { d.log.Warn("no relevant entries found in idx", zap.String("url", idxURL), zap.Int("total_entries", len(entries)), zap.Int("filtered", len(filtered))) return nil } // 3. Download byte ranges and write to temp file ranges := EntriesToRanges(relevant) tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges, prog) if err != nil { return fmt.Errorf("download ranges: %w", err) } defer os.Remove(tmpFile) // 4. 
Read GRIB messages from temp file f, err := os.Open(tmpFile) if err != nil { return fmt.Errorf("open temp grib: %w", err) } messages, err := griblib.ReadMessages(f) f.Close() if err != nil { return fmt.Errorf("read grib messages: %w", err) } // 5. Decode and blit each message into the dataset for _, msg := range messages { if msg.Section4.ProductDefinitionTemplateNumber != 0 { continue } product := msg.Section4.ProductDefinitionTemplate varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber)) if varIdx < 0 { continue } if product.FirstSurface.Type != 100 { // isobaric surface continue } pressurePa := float64(product.FirstSurface.Value) pressureMB := int(math.Round(pressurePa / 100.0)) levelIdx := dataset.PressureIndex(pressureMB) if levelIdx < 0 { continue } data := msg.Data() if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil { d.log.Warn("blit failed", zap.Int("var", varIdx), zap.Int("level_mb", pressureMB), zap.Error(err)) continue } } return nil } // downloadRangesToTempFile downloads multiple byte ranges from a URL, // concatenating them into a single temp file (valid concatenated GRIB messages). 
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) { tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp") if err != nil { return "", fmt.Errorf("create temp file: %w", err) } tmpPath := tmpFile.Name() for _, r := range ranges { data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End) if err != nil { tmpFile.Close() os.Remove(tmpPath) return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err) } if _, err := tmpFile.Write(data); err != nil { tmpFile.Close() os.Remove(tmpPath) return "", fmt.Errorf("write temp: %w", err) } if prog != nil { prog.addBytes(int64(len(data))) } } if err := tmpFile.Close(); err != nil { os.Remove(tmpPath) return "", err } return tmpPath, nil } // httpGet downloads a URL and returns the body bytes. func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) { var lastErr error for attempt := 0; attempt < 3; attempt++ { if attempt > 0 { select { case <-time.After(time.Duration(attempt*2) * time.Second): case <-ctx.Done(): return nil, ctx.Err() } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } resp, err := d.client.Do(req) if err != nil { lastErr = err continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) continue } if err != nil { lastErr = err continue } return body, nil } return nil, fmt.Errorf("after 3 attempts: %w", lastErr) } // httpGetRange downloads a byte range from a URL. 
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) { var lastErr error for attempt := 0; attempt < 3; attempt++ { if attempt > 0 { select { case <-time.After(time.Duration(attempt*2) * time.Second): case <-ctx.Done(): return nil, ctx.Err() } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) resp, err := d.client.Do(req) if err != nil { lastErr = err continue } body, err := io.ReadAll(resp.Body) resp.Body.Close() if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url) continue } if err != nil { lastErr = err continue } return body, nil } return nil, fmt.Errorf("after 3 attempts: %w", lastErr) }