engine refactor
This commit is contained in:
parent
9e663db9dc
commit
81b8e763bd
37 changed files with 3532 additions and 1639 deletions
369
internal/datasets/grib/downloader.go
Normal file
369
internal/datasets/grib/downloader.go
Normal file
|
|
@ -0,0 +1,369 @@
|
|||
package grib
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/nilsmagnus/grib/griblib"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"predictor-refactored/internal/datasets"
|
||||
wgfs "predictor-refactored/internal/weather/gfs"
|
||||
)
|
||||
|
||||
// URLFunc returns the GRIB URL for one (date, runHour, member, step, levelSet).
|
||||
// Sources that don't have members (GFS) ignore the member argument.
|
||||
type URLFunc func(date string, runHour, member, step int, ls wgfs.LevelSet) string
|
||||
|
||||
// Downloader is the generic GRIB-cube downloader.
|
||||
//
|
||||
// A Source plugs in its variant, URL templating, and member-resolution
|
||||
// logic; the Downloader runs the parallel idx fetch, byte-range download,
|
||||
// GRIB decode, and blit loop with manifest-based resume.
|
||||
type Downloader struct {
|
||||
Variant *wgfs.Variant
|
||||
URLs URLFunc
|
||||
Parallel int
|
||||
Client *http.Client
|
||||
Log *zap.Logger
|
||||
}
|
||||
|
||||
func (d *Downloader) log() *zap.Logger {
|
||||
if d.Log == nil {
|
||||
return zap.NewNop()
|
||||
}
|
||||
return d.Log
|
||||
}
|
||||
|
||||
func (d *Downloader) client() *http.Client {
|
||||
if d.Client == nil {
|
||||
return &http.Client{Timeout: 2 * time.Minute}
|
||||
}
|
||||
return d.Client
|
||||
}
|
||||
|
||||
func (d *Downloader) parallel() int {
|
||||
if d.Parallel <= 0 {
|
||||
return 8
|
||||
}
|
||||
return d.Parallel
|
||||
}
|
||||
|
||||
// neededVariables lists the GRIB variable names that every source
// extracts; it is the filter applied to parsed idx entries.
var neededVariables = map[string]bool{
	"HGT":  true,
	"UGRD": true,
	"VGRD": true,
}
|
||||
|
||||
// Run downloads the dataset for id, member into store. The caller may
|
||||
// pass member=0 for non-ensemble sources.
|
||||
func (d *Downloader) Run(ctx context.Context, id datasets.DatasetID, member int, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
|
||||
if prog == nil {
|
||||
prog = noopSink{}
|
||||
}
|
||||
|
||||
handle, err := store.BeginWrite(id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin write: %w", err)
|
||||
}
|
||||
manifest := handle.Manifest()
|
||||
|
||||
file, err := openOrCreateCube(handle.Path(), d.Variant)
|
||||
if err != nil {
|
||||
_ = handle.Abort()
|
||||
return err
|
||||
}
|
||||
|
||||
epoch := id.Epoch.UTC()
|
||||
date := epoch.Format("20060102")
|
||||
runHour := epoch.Hour()
|
||||
|
||||
steps := d.Variant.Hours()
|
||||
if hr := id.Subset.HourRange; hr != nil {
|
||||
filtered := steps[:0]
|
||||
for _, step := range steps {
|
||||
if step >= hr.MinHour && step <= hr.MaxHour {
|
||||
filtered = append(filtered, step)
|
||||
}
|
||||
}
|
||||
steps = filtered
|
||||
}
|
||||
prog.SetTotal(len(steps) * 2)
|
||||
for range manifest.Units() {
|
||||
prog.StepComplete()
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.SetLimit(d.parallel())
|
||||
var fileMu sync.Mutex
|
||||
|
||||
for _, step := range steps {
|
||||
hourIdx := d.Variant.HourIndex(step)
|
||||
if hourIdx < 0 {
|
||||
continue
|
||||
}
|
||||
for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} {
|
||||
unit := unitKey(step, ls)
|
||||
if manifest.Has(unit) {
|
||||
continue
|
||||
}
|
||||
g.Go(func() error {
|
||||
url := d.URLs(date, runHour, member, step, ls)
|
||||
if err := d.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil {
|
||||
return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err)
|
||||
}
|
||||
if err := manifest.Mark(unit); err != nil {
|
||||
return fmt.Errorf("mark unit: %w", err)
|
||||
}
|
||||
prog.StepComplete()
|
||||
return nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
_ = file.Close()
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return err
|
||||
}
|
||||
if len(manifest.Units()) == 0 {
|
||||
_ = handle.Abort()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if err := file.Flush(); err != nil {
|
||||
_ = file.Close()
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
if err := file.Close(); err != nil {
|
||||
return fmt.Errorf("close: %w", err)
|
||||
}
|
||||
if err := handle.Commit(); err != nil {
|
||||
return fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
d.log().Info("download complete",
|
||||
zap.String("variant", d.Variant.ID),
|
||||
zap.Time("epoch", epoch),
|
||||
zap.Duration("elapsed", time.Since(start)))
|
||||
return nil
|
||||
}
|
||||
|
||||
// openOrCreateCube opens an existing cube at path if it matches variant's
|
||||
// expected size, else truncate-creates a new one.
|
||||
func openOrCreateCube(path string, variant *wgfs.Variant) (*wgfs.File, error) {
|
||||
info, err := os.Stat(path)
|
||||
if err == nil && info.Size() == variant.DatasetSize() {
|
||||
return wgfs.OpenWritable(path, variant)
|
||||
}
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("stat cube: %w", err)
|
||||
}
|
||||
return wgfs.Create(path, variant)
|
||||
}
|
||||
|
||||
// downloadAndBlit fetches and decodes one (URL, level-set) chunk.
|
||||
func (d *Downloader) downloadAndBlit(
|
||||
ctx context.Context,
|
||||
file *wgfs.File,
|
||||
fileMu *sync.Mutex,
|
||||
baseURL string,
|
||||
hourIdx int,
|
||||
ls wgfs.LevelSet,
|
||||
prog datasets.ProgressSink,
|
||||
throttle datasets.Throttle,
|
||||
) error {
|
||||
idxBody, err := d.httpGet(ctx, baseURL+".idx", throttle, prog)
|
||||
if err != nil {
|
||||
return fmt.Errorf("idx: %w", err)
|
||||
}
|
||||
entries := ParseIdx(idxBody)
|
||||
filtered := FilterIdx(entries, neededVariables)
|
||||
|
||||
var relevant []IdxEntry
|
||||
for _, e := range filtered {
|
||||
set, ok := d.Variant.PressureLevelSet(e.LevelMB)
|
||||
if ok && set == ls {
|
||||
relevant = append(relevant, e)
|
||||
}
|
||||
}
|
||||
if len(relevant) == 0 {
|
||||
return nil
|
||||
}
|
||||
ranges := EntriesToRanges(relevant)
|
||||
|
||||
tmp, err := os.CreateTemp("", "grib-msg-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("temp: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
for _, r := range ranges {
|
||||
body, err := d.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog)
|
||||
if err != nil {
|
||||
tmp.Close()
|
||||
return fmt.Errorf("range: %w", err)
|
||||
}
|
||||
if _, err := tmp.Write(body); err != nil {
|
||||
tmp.Close()
|
||||
return fmt.Errorf("write tmp: %w", err)
|
||||
}
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Open(tmpPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
messages, err := griblib.ReadMessages(f)
|
||||
f.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("read grib: %w", err)
|
||||
}
|
||||
|
||||
for _, msg := range messages {
|
||||
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
|
||||
continue
|
||||
}
|
||||
p := msg.Section4.ProductDefinitionTemplate
|
||||
varIdx := d.Variant.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber))
|
||||
if varIdx < 0 {
|
||||
continue
|
||||
}
|
||||
if p.FirstSurface.Type != 100 {
|
||||
continue
|
||||
}
|
||||
pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0))
|
||||
levelIdx := d.Variant.PressureIndex(pressureMB)
|
||||
if levelIdx < 0 {
|
||||
continue
|
||||
}
|
||||
data := msg.Data()
|
||||
fileMu.Lock()
|
||||
err := file.BlitGribData(hourIdx, levelIdx, varIdx, data)
|
||||
fileMu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("blit: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *Downloader) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
|
||||
var lastErr error
|
||||
for attempt := range 3 {
|
||||
if attempt > 0 {
|
||||
select {
|
||||
case <-time.After(time.Duration(attempt*2) * time.Second):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp, err := d.client().Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
body, err := readThrottled(ctx, resp.Body, throttle, prog)
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
|
||||
var lastErr error
|
||||
for attempt := range 3 {
|
||||
if attempt > 0 {
|
||||
select {
|
||||
case <-time.After(time.Duration(attempt*2) * time.Second):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
|
||||
resp, err := d.client().Do(req)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
body, err := readThrottled(ctx, resp.Body, throttle, prog)
|
||||
resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
|
||||
lastErr = fmt.Errorf("HTTP %d for range", resp.StatusCode)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
|
||||
}
|
||||
|
||||
func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
|
||||
buf := make([]byte, 0, 64*1024)
|
||||
chunk := make([]byte, 32*1024)
|
||||
for {
|
||||
if throttle != nil {
|
||||
if err := throttle.Wait(ctx, len(chunk)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
n, err := r.Read(chunk)
|
||||
if n > 0 {
|
||||
buf = append(buf, chunk[:n]...)
|
||||
prog.Bytes(int64(n))
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return buf, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func unitKey(step int, ls wgfs.LevelSet) string {
|
||||
return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls))
|
||||
}
|
||||
|
||||
func levelSetLabel(ls wgfs.LevelSet) string {
|
||||
if ls == wgfs.LevelSetB {
|
||||
return "B"
|
||||
}
|
||||
return "A"
|
||||
}
|
||||
|
||||
// noopSink is a progress sink that discards every event. Run substitutes
// it when the caller passes a nil sink.
type noopSink struct{}

// SetTotal ignores the announced number of steps.
func (noopSink) SetTotal(_ int) {}

// StepComplete ignores the completion of a step.
func (noopSink) StepComplete() {}

// Bytes ignores the reported byte count.
func (noopSink) Bytes(_ int64) {}
|
||||
Loading…
Add table
Add a link
Reference in a new issue