380 lines
9.4 KiB
Go
380 lines
9.4 KiB
Go
package downloader
|
|
|
|
import (
	"context"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"sync/atomic"
	"time"

	"predictor-refactored/internal/dataset"

	"github.com/nilsmagnus/grib/griblib"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)
|
|
|
|
// Downloader handles fetching GFS forecast data from S3 and assembling dataset files.
// All HTTP access goes through the shared client; methods are safe to call
// concurrently as long as the underlying dataset.File supports it.
type Downloader struct {
	cfg *Config // download configuration (data dir, parallelism, ...)
	client *http.Client // shared client; 2-minute per-request timeout set in NewDownloader
	log *zap.Logger // structured logger for progress and warnings
}
|
|
|
|
// NewDownloader creates a new Downloader.
|
|
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
|
|
return &Downloader{
|
|
cfg: cfg,
|
|
client: &http.Client{
|
|
Timeout: 2 * time.Minute,
|
|
},
|
|
log: log,
|
|
}
|
|
}
|
|
|
|
// neededVariables is the set of GRIB variable names we need.
// These are NOAA GRIB short names: HGT (geopotential height) and
// UGRD/VGRD (the U/V wind components).
var neededVariables = map[string]bool{
	"HGT": true,
	"UGRD": true,
	"VGRD": true,
}
|
|
|
|
// FindLatestRun finds the most recent available GFS model run on S3.
|
|
// It checks the last forecast step of each run to confirm availability.
|
|
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
|
|
now := time.Now().UTC()
|
|
hour := now.Hour() - (now.Hour() % 6)
|
|
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
|
|
|
|
for i := 0; i < 8; i++ {
|
|
date := current.Format("20060102")
|
|
url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx"
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
|
|
if err != nil {
|
|
current = current.Add(-6 * time.Hour)
|
|
continue
|
|
}
|
|
|
|
resp, err := d.client.Do(req)
|
|
if err == nil {
|
|
resp.Body.Close()
|
|
if resp.StatusCode == http.StatusOK {
|
|
d.log.Info("found latest model run",
|
|
zap.Time("run", current),
|
|
zap.String("verified_url", url))
|
|
return current, nil
|
|
}
|
|
}
|
|
|
|
current = current.Add(-6 * time.Hour)
|
|
}
|
|
|
|
return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
|
|
}
|
|
|
|
// Download downloads a complete forecast and assembles a dataset file.
|
|
// Returns the path to the completed dataset file.
|
|
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
|
|
date := run.Format("20060102")
|
|
runHour := run.Hour()
|
|
|
|
finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215"))
|
|
tempPath := finalPath + ".downloading"
|
|
|
|
// Check if final dataset already exists
|
|
if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize {
|
|
d.log.Info("dataset already exists", zap.String("path", finalPath))
|
|
return finalPath, nil
|
|
}
|
|
|
|
d.log.Info("starting dataset download",
|
|
zap.Time("run", run),
|
|
zap.String("temp_path", tempPath))
|
|
|
|
// Create the dataset file
|
|
ds, err := dataset.Create(tempPath)
|
|
if err != nil {
|
|
return "", fmt.Errorf("create dataset: %w", err)
|
|
}
|
|
defer ds.Close()
|
|
|
|
steps := dataset.Hours()
|
|
totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
|
|
completed := 0
|
|
|
|
// Process each forecast step with bounded concurrency
|
|
g, ctx := errgroup.WithContext(ctx)
|
|
sem := make(chan struct{}, d.cfg.Parallel)
|
|
|
|
for _, step := range steps {
|
|
step := step
|
|
hourIdx := dataset.HourIndex(step)
|
|
if hourIdx < 0 {
|
|
continue
|
|
}
|
|
|
|
// Download pgrb2 (level set A)
|
|
sem <- struct{}{}
|
|
g.Go(func() error {
|
|
defer func() { <-sem }()
|
|
url := dataset.GribURL(date, runHour, step)
|
|
err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA)
|
|
if err != nil {
|
|
return fmt.Errorf("step %d pgrb2: %w", step, err)
|
|
}
|
|
completed++
|
|
d.log.Debug("step complete",
|
|
zap.Int("step", step),
|
|
zap.String("set", "pgrb2"),
|
|
zap.Int("progress", completed),
|
|
zap.Int("total", totalSteps))
|
|
return nil
|
|
})
|
|
|
|
// Download pgrb2b (level set B)
|
|
sem <- struct{}{}
|
|
g.Go(func() error {
|
|
defer func() { <-sem }()
|
|
url := dataset.GribURLB(date, runHour, step)
|
|
err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB)
|
|
if err != nil {
|
|
return fmt.Errorf("step %d pgrb2b: %w", step, err)
|
|
}
|
|
completed++
|
|
d.log.Debug("step complete",
|
|
zap.Int("step", step),
|
|
zap.String("set", "pgrb2b"),
|
|
zap.Int("progress", completed),
|
|
zap.Int("total", totalSteps))
|
|
return nil
|
|
})
|
|
}
|
|
|
|
if err := g.Wait(); err != nil {
|
|
os.Remove(tempPath)
|
|
return "", err
|
|
}
|
|
|
|
// Flush to disk
|
|
if err := ds.Flush(); err != nil {
|
|
os.Remove(tempPath)
|
|
return "", fmt.Errorf("flush dataset: %w", err)
|
|
}
|
|
|
|
// Close before rename
|
|
ds.Close()
|
|
|
|
// Atomic rename
|
|
if err := os.Rename(tempPath, finalPath); err != nil {
|
|
os.Remove(tempPath)
|
|
return "", fmt.Errorf("rename dataset: %w", err)
|
|
}
|
|
|
|
d.log.Info("dataset download complete", zap.String("path", finalPath))
|
|
return finalPath, nil
|
|
}
|
|
|
|
// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset.
//
// Flow: fetch the .idx sidecar, filter its entries down to neededVariables
// and the pressure levels belonging to levelSet, download only those byte
// ranges into a temp GRIB file, then decode each message and blit its grid
// into the dataset at (hourIdx, levelIdx, varIdx). Messages that cannot be
// mapped (unknown variable, non-isobaric surface, untracked pressure level)
// are silently skipped; a failed blit is logged and skipped rather than
// aborting the whole step.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
	// 1. Download .idx (the index sidecar listing byte offsets per field)
	idxURL := baseURL + ".idx"
	idxBody, err := d.httpGet(ctx, idxURL)
	if err != nil {
		return fmt.Errorf("download idx: %w", err)
	}

	// 2. Parse and filter to the variables we care about
	entries := ParseIdx(idxBody)
	filtered := FilterIdx(entries, neededVariables)

	// Further filter to only levels in this level set
	var relevant []IdxEntry
	for _, e := range filtered {
		ls, ok := dataset.PressureLevelSet(e.LevelMB)
		if ok && ls == levelSet {
			relevant = append(relevant, e)
		}
	}

	if len(relevant) == 0 {
		// Not fatal: this file may simply carry nothing from this level set.
		d.log.Warn("no relevant entries found in idx",
			zap.String("url", idxURL),
			zap.Int("total_entries", len(entries)),
			zap.Int("filtered", len(filtered)))
		return nil
	}

	// 3. Download byte ranges and write to temp file
	ranges := EntriesToRanges(relevant)
	tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges)
	if err != nil {
		return fmt.Errorf("download ranges: %w", err)
	}
	defer os.Remove(tmpFile)

	// 4. Read GRIB messages from temp file
	f, err := os.Open(tmpFile)
	if err != nil {
		return fmt.Errorf("open temp grib: %w", err)
	}

	messages, err := griblib.ReadMessages(f)
	f.Close()
	if err != nil {
		return fmt.Errorf("read grib messages: %w", err)
	}

	// 5. Decode and blit each message into the dataset
	for _, msg := range messages {
		// Only product definition template 0 (instantaneous analysis/forecast
		// in GRIB2) is handled; everything else is skipped.
		if msg.Section4.ProductDefinitionTemplateNumber != 0 {
			continue
		}

		product := msg.Section4.ProductDefinitionTemplate

		// Map GRIB (category, number) to the dataset's variable index.
		varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber))
		if varIdx < 0 {
			continue
		}

		if product.FirstSurface.Type != 100 { // isobaric surface
			continue
		}

		// GRIB encodes isobaric surfaces in Pa; the dataset indexes by mb.
		pressurePa := float64(product.FirstSurface.Value)
		pressureMB := int(math.Round(pressurePa / 100.0))
		levelIdx := dataset.PressureIndex(pressureMB)
		if levelIdx < 0 {
			continue
		}

		data := msg.Data()
		if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil {
			// Best effort: log and continue with the remaining messages.
			d.log.Warn("blit failed",
				zap.Int("var", varIdx),
				zap.Int("level_mb", pressureMB),
				zap.Error(err))
			continue
		}
	}

	return nil
}
|
|
|
|
// downloadRangesToTempFile downloads multiple byte ranges from a URL,
|
|
// concatenating them into a single temp file (valid concatenated GRIB messages).
|
|
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange) (string, error) {
|
|
tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
|
|
if err != nil {
|
|
return "", fmt.Errorf("create temp file: %w", err)
|
|
}
|
|
tmpPath := tmpFile.Name()
|
|
|
|
for _, r := range ranges {
|
|
data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End)
|
|
if err != nil {
|
|
tmpFile.Close()
|
|
os.Remove(tmpPath)
|
|
return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err)
|
|
}
|
|
if _, err := tmpFile.Write(data); err != nil {
|
|
tmpFile.Close()
|
|
os.Remove(tmpPath)
|
|
return "", fmt.Errorf("write temp: %w", err)
|
|
}
|
|
}
|
|
|
|
if err := tmpFile.Close(); err != nil {
|
|
os.Remove(tmpPath)
|
|
return "", err
|
|
}
|
|
|
|
return tmpPath, nil
|
|
}
|
|
|
|
// httpGet downloads a URL and returns the body bytes.
|
|
func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) {
|
|
var lastErr error
|
|
for attempt := 0; attempt < 3; attempt++ {
|
|
if attempt > 0 {
|
|
select {
|
|
case <-time.After(time.Duration(attempt*2) * time.Second):
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
resp, err := d.client.Do(req)
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
|
|
continue
|
|
}
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
|
|
return body, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
|
|
}
|
|
|
|
// httpGetRange downloads a byte range from a URL.
|
|
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) {
|
|
var lastErr error
|
|
for attempt := 0; attempt < 3; attempt++ {
|
|
if attempt > 0 {
|
|
select {
|
|
case <-time.After(time.Duration(attempt*2) * time.Second):
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
|
|
|
|
resp, err := d.client.Do(req)
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
|
|
lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
|
|
continue
|
|
}
|
|
if err != nil {
|
|
lastErr = err
|
|
continue
|
|
}
|
|
|
|
return body, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
|
|
}
|