engine refactor

This commit is contained in:
Anatoly Antonov 2026-05-23 00:55:35 +09:00
parent 9e663db9dc
commit 81b8e763bd
37 changed files with 3532 additions and 1639 deletions

View file

@ -0,0 +1,369 @@
package grib
import (
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"os"
"sync"
"time"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"predictor-refactored/internal/datasets"
wgfs "predictor-refactored/internal/weather/gfs"
)
// URLFunc returns the GRIB URL for one (date, runHour, member, step, levelSet).
// Sources that don't have members (GFS) ignore the member argument.
type URLFunc func(date string, runHour, member, step int, ls wgfs.LevelSet) string
// Downloader is the generic GRIB-cube downloader.
//
// A Source plugs in its variant, URL templating, and member-resolution
// logic; the Downloader runs the parallel idx fetch, byte-range download,
// GRIB decode, and blit loop with manifest-based resume.
type Downloader struct {
Variant *wgfs.Variant
URLs URLFunc
Parallel int
Client *http.Client
Log *zap.Logger
}
func (d *Downloader) log() *zap.Logger {
if d.Log == nil {
return zap.NewNop()
}
return d.Log
}
func (d *Downloader) client() *http.Client {
if d.Client == nil {
return &http.Client{Timeout: 2 * time.Minute}
}
return d.Client
}
func (d *Downloader) parallel() int {
if d.Parallel <= 0 {
return 8
}
return d.Parallel
}
// neededVariables is the GRIB variable set every source extracts.
var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
// Run downloads the dataset for id, member into store. The caller may
// pass member=0 for non-ensemble sources.
func (d *Downloader) Run(ctx context.Context, id datasets.DatasetID, member int, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
if prog == nil {
prog = noopSink{}
}
handle, err := store.BeginWrite(id)
if err != nil {
return fmt.Errorf("begin write: %w", err)
}
manifest := handle.Manifest()
file, err := openOrCreateCube(handle.Path(), d.Variant)
if err != nil {
_ = handle.Abort()
return err
}
epoch := id.Epoch.UTC()
date := epoch.Format("20060102")
runHour := epoch.Hour()
steps := d.Variant.Hours()
if hr := id.Subset.HourRange; hr != nil {
filtered := steps[:0]
for _, step := range steps {
if step >= hr.MinHour && step <= hr.MaxHour {
filtered = append(filtered, step)
}
}
steps = filtered
}
prog.SetTotal(len(steps) * 2)
for range manifest.Units() {
prog.StepComplete()
}
start := time.Now()
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(d.parallel())
var fileMu sync.Mutex
for _, step := range steps {
hourIdx := d.Variant.HourIndex(step)
if hourIdx < 0 {
continue
}
for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} {
unit := unitKey(step, ls)
if manifest.Has(unit) {
continue
}
g.Go(func() error {
url := d.URLs(date, runHour, member, step, ls)
if err := d.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil {
return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err)
}
if err := manifest.Mark(unit); err != nil {
return fmt.Errorf("mark unit: %w", err)
}
prog.StepComplete()
return nil
})
}
}
if err := g.Wait(); err != nil {
_ = file.Close()
if errors.Is(err, context.Canceled) {
return err
}
if len(manifest.Units()) == 0 {
_ = handle.Abort()
}
return err
}
if err := file.Flush(); err != nil {
_ = file.Close()
return fmt.Errorf("flush: %w", err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}
if err := handle.Commit(); err != nil {
return fmt.Errorf("commit: %w", err)
}
d.log().Info("download complete",
zap.String("variant", d.Variant.ID),
zap.Time("epoch", epoch),
zap.Duration("elapsed", time.Since(start)))
return nil
}
// openOrCreateCube opens an existing cube at path if it matches variant's
// expected size, else truncate-creates a new one.
func openOrCreateCube(path string, variant *wgfs.Variant) (*wgfs.File, error) {
info, err := os.Stat(path)
if err == nil && info.Size() == variant.DatasetSize() {
return wgfs.OpenWritable(path, variant)
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat cube: %w", err)
}
return wgfs.Create(path, variant)
}
// downloadAndBlit fetches and decodes one (URL, level-set) chunk.
func (d *Downloader) downloadAndBlit(
ctx context.Context,
file *wgfs.File,
fileMu *sync.Mutex,
baseURL string,
hourIdx int,
ls wgfs.LevelSet,
prog datasets.ProgressSink,
throttle datasets.Throttle,
) error {
idxBody, err := d.httpGet(ctx, baseURL+".idx", throttle, prog)
if err != nil {
return fmt.Errorf("idx: %w", err)
}
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
var relevant []IdxEntry
for _, e := range filtered {
set, ok := d.Variant.PressureLevelSet(e.LevelMB)
if ok && set == ls {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
return nil
}
ranges := EntriesToRanges(relevant)
tmp, err := os.CreateTemp("", "grib-msg-*.tmp")
if err != nil {
return fmt.Errorf("temp: %w", err)
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
for _, r := range ranges {
body, err := d.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog)
if err != nil {
tmp.Close()
return fmt.Errorf("range: %w", err)
}
if _, err := tmp.Write(body); err != nil {
tmp.Close()
return fmt.Errorf("write tmp: %w", err)
}
}
if err := tmp.Close(); err != nil {
return err
}
f, err := os.Open(tmpPath)
if err != nil {
return err
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib: %w", err)
}
for _, msg := range messages {
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
p := msg.Section4.ProductDefinitionTemplate
varIdx := d.Variant.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber))
if varIdx < 0 {
continue
}
if p.FirstSurface.Type != 100 {
continue
}
pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0))
levelIdx := d.Variant.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
fileMu.Lock()
err := file.BlitGribData(hourIdx, levelIdx, varIdx, data)
fileMu.Unlock()
if err != nil {
return fmt.Errorf("blit: %w", err)
}
}
return nil
}
func (d *Downloader) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := d.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := d.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for range", resp.StatusCode)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
buf := make([]byte, 0, 64*1024)
chunk := make([]byte, 32*1024)
for {
if throttle != nil {
if err := throttle.Wait(ctx, len(chunk)); err != nil {
return nil, err
}
}
n, err := r.Read(chunk)
if n > 0 {
buf = append(buf, chunk[:n]...)
prog.Bytes(int64(n))
}
if errors.Is(err, io.EOF) {
return buf, nil
}
if err != nil {
return nil, err
}
}
}
func unitKey(step int, ls wgfs.LevelSet) string {
return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls))
}
func levelSetLabel(ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return "B"
}
return "A"
}
type noopSink struct{}
func (noopSink) SetTotal(int) {}
func (noopSink) StepComplete() {}
func (noopSink) Bytes(int64) {}

View file

@ -0,0 +1,129 @@
// Package grib contains the GRIB-cube download skeleton shared by every
// NOAA source (GFS, GEFS, future families). It exposes the .idx parser,
// HTTP helpers, and a parallel download loop; source-specific URL
// templating is injected by the caller.
package grib
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry is one parsed line from a NOAA GRIB .idx file.
//
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
type IdxEntry struct {
Index int
Offset int64
Variable string
LevelMB int // 0 when the level is not isobaric
Hour int // forecast hour; 0 for analysis ("anl"); -1 if unparseable
EndOffset int64 // computed from the next entry's Offset; -1 for the final entry
}
// Length returns the byte length of this GRIB message, or -1 if unknown
// (the final entry in an idx file).
func (e *IdxEntry) Length() int64 {
if e.EndOffset <= 0 {
return -1
}
return e.EndOffset - e.Offset
}
// ParseIdx parses a .idx file body. Unparseable lines are silently skipped.
func ParseIdx(body []byte) []IdxEntry {
lines := strings.Split(string(body), "\n")
var entries []IdxEntry
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.Split(line, ":")
if len(parts) < 7 {
continue
}
idx, err := strconv.Atoi(parts[0])
if err != nil {
continue
}
off, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
continue
}
entries = append(entries, IdxEntry{
Index: idx,
Offset: off,
Variable: parts[3],
LevelMB: parseLevelMB(parts[4]),
Hour: parseHour(parts[5]),
EndOffset: -1,
})
}
for i := 0; i < len(entries)-1; i++ {
entries[i].EndOffset = entries[i+1].Offset
}
return entries
}
// FilterIdx returns entries matching one of the wanted variables at a known
// pressure level with a computable byte length.
func FilterIdx(entries []IdxEntry, wanted map[string]bool) []IdxEntry {
var out []IdxEntry
for _, e := range entries {
if !wanted[e.Variable] || e.LevelMB <= 0 || e.Length() <= 0 {
continue
}
out = append(out, e)
}
return out
}
func parseLevelMB(s string) int {
s = strings.TrimSpace(s)
if !strings.HasSuffix(s, " mb") {
return 0
}
n, err := strconv.Atoi(strings.TrimSuffix(s, " mb"))
if err != nil {
return 0
}
return n
}
func parseHour(s string) int {
s = strings.TrimSpace(s)
if s == "anl" {
return 0
}
n, err := strconv.Atoi(strings.TrimSuffix(s, " hour fcst"))
if err != nil {
return -1
}
return n
}
// ByteRange is one HTTP range download corresponding to one GRIB message.
type ByteRange struct {
Start int64
End int64 // inclusive
Entry IdxEntry
}
// EntriesToRanges converts idx entries to inclusive HTTP byte ranges.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
out := make([]ByteRange, 0, len(entries))
for _, e := range entries {
if e.Length() <= 0 {
continue
}
out = append(out, ByteRange{Start: e.Offset, End: e.EndOffset - 1, Entry: e})
}
return out
}
// FormatRange returns an HTTP Range header value for the byte range.
func (r ByteRange) FormatRange() string {
return fmt.Sprintf("bytes=%d-%d", r.Start, r.End)
}

View file

@ -0,0 +1,70 @@
package grib
import "testing"
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
func TestParseIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
if len(entries) != 9 {
t.Fatalf("expected 9 entries, got %d", len(entries))
}
if e := entries[0]; e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 || e.EndOffset != 289012 {
t.Errorf("entry 0: %+v", e)
}
if e := entries[6]; e.LevelMB != 0 {
t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB)
}
if e := entries[len(entries)-1]; e.EndOffset != -1 {
t.Errorf("last entry EndOffset: got %d, want -1", e.EndOffset)
}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
want := map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
filtered := FilterIdx(entries, want)
// HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// HGT@500 at 3hr is last entry (no EndOffset), so dropped.
if len(filtered) != 6 {
t.Errorf("expected 6, got %d", len(filtered))
}
}
func TestParseLevelMB(t *testing.T) {
cases := []struct {
in string
want int
}{
{"1000 mb", 1000}, {"975 mb", 975}, {"1 mb", 1},
{"2 m above ground", 0}, {"surface", 0}, {"tropopause", 0},
}
for _, c := range cases {
if got := parseLevelMB(c.in); got != c.want {
t.Errorf("parseLevelMB(%q) = %d, want %d", c.in, got, c.want)
}
}
}
func TestParseHour(t *testing.T) {
cases := []struct {
in string
want int
}{
{"0 hour fcst", 0}, {"3 hour fcst", 3}, {"192 hour fcst", 192}, {"anl", 0},
}
for _, c := range cases {
if got := parseHour(c.in); got != c.want {
t.Errorf("parseHour(%q) = %d, want %d", c.in, got, c.want)
}
}
}