Refactor #7

Open
a.antonov wants to merge 8 commits from afanasyev.aa/predictor:refactor into main
Showing only changes of commit e14a00287c - Show all commits

View file

@ -8,6 +8,7 @@ import (
"net/http"
"os"
"path/filepath"
"sync/atomic"
"time"
"predictor-refactored/internal/dataset"
@ -76,6 +77,60 @@ func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}
// progress tracks download progress across concurrent goroutines.
//
// bytesDownloaded and stepsCompleted use atomic types so multiple
// download goroutines can update them without a mutex. totalSteps,
// startTime, and log are set once in newProgress and are read-only
// afterwards, so they need no synchronization.
type progress struct {
	bytesDownloaded atomic.Int64 // total bytes fetched so far, across all goroutines
	stepsCompleted  atomic.Int64 // forecast steps fully downloaded and written
	totalSteps      int64        // fixed step count set at construction
	startTime       time.Time    // run start; basis for speed and ETA estimates
	log             *zap.Logger
}
// newProgress returns a tracker for totalSteps download steps, stamping
// the current time as the start of the run for speed/ETA calculations.
func newProgress(totalSteps int, log *zap.Logger) *progress {
	p := &progress{log: log}
	p.totalSteps = int64(totalSteps)
	p.startTime = time.Now()
	return p
}
// addBytes records n freshly downloaded bytes. Safe for concurrent use
// by multiple download goroutines (atomic add).
func (p *progress) addBytes(n int64) {
	p.bytesDownloaded.Add(n)
}
// completeStep records one finished download step and logs an aggregate
// progress snapshot: steps done, percent complete, MB downloaded, average
// throughput, and an ETA extrapolated from the per-step pace so far.
//
// Safe for concurrent use: the counters are atomic and each call logs the
// values it observed at entry.
func (p *progress) completeStep() {
	done := p.stepsCompleted.Add(1)
	total := p.totalSteps
	bytes := p.bytesDownloaded.Load()
	elapsed := time.Since(p.startTime).Seconds()

	pct := float64(done) / float64(total) * 100
	mbDownloaded := float64(bytes) / (1024 * 1024)
	mbPerSec := 0.0
	if elapsed > 0 {
		mbPerSec = mbDownloaded / elapsed
	}

	// Estimate remaining time from the average pace of completed steps.
	eta := ""
	if done > 0 && done < total {
		secsPerStep := elapsed / float64(done)
		remaining := secsPerStep * float64(total-done)
		// Round to whole seconds BEFORE splitting into minutes/seconds.
		// The previous float formatting ("%02.0f" on math.Mod(remaining, 60))
		// rounded the seconds component independently of the minutes, so a
		// remaining time of e.g. 119.7s printed the impossible "1m60s".
		secs := int64(math.Round(remaining))
		if secs >= 60 {
			eta = fmt.Sprintf("%dm%02ds", secs/60, secs%60)
		} else {
			eta = fmt.Sprintf("%ds", secs)
		}
	}

	p.log.Info("download progress",
		zap.String("progress", fmt.Sprintf("%d/%d", done, total)),
		zap.String("percent", fmt.Sprintf("%.1f%%", pct)),
		zap.String("downloaded", fmt.Sprintf("%.1f MB", mbDownloaded)),
		zap.String("speed", fmt.Sprintf("%.1f MB/s", mbPerSec)),
		zap.String("eta", eta))
}
// Download downloads a complete forecast and assembles a dataset file.
// Returns the path to the completed dataset file.
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
@ -91,8 +146,13 @@ func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error
return finalPath, nil
}
steps := dataset.Hours()
totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
prog := newProgress(totalSteps, d.log)
d.log.Info("starting dataset download",
zap.Time("run", run),
zap.Int("total_steps", totalSteps),
zap.String("temp_path", tempPath))
// Create the dataset file
@ -102,10 +162,6 @@ func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error
}
defer ds.Close()
steps := dataset.Hours()
totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
completed := 0
// Process each forecast step with bounded concurrency
g, ctx := errgroup.WithContext(ctx)
sem := make(chan struct{}, d.cfg.Parallel)
@ -122,16 +178,11 @@ func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error
g.Go(func() error {
defer func() { <-sem }()
url := dataset.GribURL(date, runHour, step)
err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA)
err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA, prog)
if err != nil {
return fmt.Errorf("step %d pgrb2: %w", step, err)
}
completed++
d.log.Debug("step complete",
zap.Int("step", step),
zap.String("set", "pgrb2"),
zap.Int("progress", completed),
zap.Int("total", totalSteps))
prog.completeStep()
return nil
})
@ -140,16 +191,11 @@ func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error
g.Go(func() error {
defer func() { <-sem }()
url := dataset.GribURLB(date, runHour, step)
err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB)
err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB, prog)
if err != nil {
return fmt.Errorf("step %d pgrb2b: %w", step, err)
}
completed++
d.log.Debug("step complete",
zap.Int("step", step),
zap.String("set", "pgrb2b"),
zap.Int("progress", completed),
zap.Int("total", totalSteps))
prog.completeStep()
return nil
})
}
@ -159,6 +205,13 @@ func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error
return "", err
}
elapsed := time.Since(prog.startTime)
totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024)
d.log.Info("download complete, flushing to disk",
zap.String("downloaded", fmt.Sprintf("%.1f MB", totalMB)),
zap.Duration("elapsed", elapsed),
zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds())))
// Flush to disk
if err := ds.Flush(); err != nil {
os.Remove(tempPath)
@ -174,12 +227,17 @@ func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error
return "", fmt.Errorf("rename dataset: %w", err)
}
d.log.Info("dataset download complete", zap.String("path", finalPath))
d.log.Info("dataset ready", zap.String("path", finalPath))
return finalPath, nil
}
// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset.
//
// It is the exported entry point and delegates to downloadAndBlit with a
// nil *progress, i.e. with progress tracking disabled.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
	return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil)
}
// downloadAndBlit is the internal implementation with optional progress tracking.
func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error {
// 1. Download .idx
idxURL := baseURL + ".idx"
idxBody, err := d.httpGet(ctx, idxURL)
@ -210,7 +268,7 @@ func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, base
// 3. Download byte ranges and write to temp file
ranges := EntriesToRanges(relevant)
tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges)
tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges, prog)
if err != nil {
return fmt.Errorf("download ranges: %w", err)
}
@ -267,7 +325,7 @@ func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, base
// downloadRangesToTempFile downloads multiple byte ranges from a URL,
// concatenating them into a single temp file (valid concatenated GRIB messages).
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange) (string, error) {
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) {
tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
if err != nil {
return "", fmt.Errorf("create temp file: %w", err)
@ -286,6 +344,9 @@ func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL strin
os.Remove(tmpPath)
return "", fmt.Errorf("write temp: %w", err)
}
if prog != nil {
prog.addBytes(int64(len(data)))
}
}
if err := tmpFile.Close(); err != nil {