predictor/internal/pkg/grib/partial_downloader.go
2026-03-22 16:29:53 +09:00

350 lines
8.3 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package grib
import (
"bufio"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)
// PartialDownloader downloads only the required fields from GRIB files,
// using HTTP Range requests and .idx index files.
type PartialDownloader struct {
// Dir is the directory that downloaded files are written into.
Dir string
// Parallel is the capacity of the semaphore Run uses to bound the number
// of downloads in flight.
Parallel int
// Client performs every HTTP request issued by the downloader.
Client *http.Client
// Variables lists the variable names to keep; filterEntries compares them
// for exact equality with each index entry's Variable.
Variables []string
// ds supplies the forecast steps, URLs and file names that Run iterates over.
ds *DatasetConfig
}
// NewPartialDownloader builds a PartialDownloader that writes into dir and
// takes its variable list from dc.
//
// parallel is stored as-is in Parallel. The HTTP client is created with a
// 60-second timeout. dc is dereferenced here, so it must not be nil.
func NewPartialDownloader(dir string, parallel int, dc *DatasetConfig) *PartialDownloader {
	httpClient := &http.Client{Timeout: 60 * time.Second}

	pd := new(PartialDownloader)
	pd.Dir = dir
	pd.Parallel = parallel
	pd.Client = httpClient
	pd.Variables = dc.Variables
	pd.ds = dc
	return pd
}
// idxEntry is one record parsed from a .idx file. parseIdx fills the string
// fields from the colon-separated columns of a single index line.
type idxEntry struct {
// Index is the position of the entry among the successfully parsed entries.
Index int
// ByteStart is the offset parsed from the second column of the line.
ByteStart int64
// Date is the third column, stored verbatim.
Date string
// Variable is the fourth column; filterEntries matches it against the
// requested variable names.
Variable string
// Level is the fifth column; filterEntries keeps entries whose level ends in " mb".
Level string
// Forecast is the sixth column, stored verbatim.
Forecast string
}
// ProgressWriter counts the bytes passed to Write and reports the running
// completion percentage through a callback. It does not store or forward
// the data it is given.
type ProgressWriter struct {
	// Total is the expected number of bytes. Percentages are reported only
	// while it is positive.
	Total int64
	// Downloaded is the running sum of the lengths passed to Write.
	Downloaded int64
	// OnProgress, when non-nil, receives the completion percentage after
	// each Write.
	OnProgress func(percent float64)
}

// Write adds len(p) to Downloaded and, when both Total and OnProgress are
// set, reports Downloaded/Total as a percentage. It always returns len(p)
// and a nil error.
func (pw *ProgressWriter) Write(p []byte) (int, error) {
	size := len(p)
	pw.Downloaded += int64(size)

	if pw.OnProgress == nil || pw.Total <= 0 {
		return size, nil
	}

	pw.OnProgress(float64(pw.Downloaded) / float64(pw.Total) * 100)
	return size, nil
}
// parseIdx parses the body of a .idx file and returns its entries in file
// order.
//
// Each non-blank line is split on ':' and must yield at least seven fields;
// the second field is the byte offset and fields three to six become Date,
// Variable, Level and Forecast. Lines that are too short, or whose offset is
// not a non-negative integer, are skipped: keeping them with a zero offset
// would corrupt the byte ranges later derived from neighbouring entries.
// Index is the position of the entry in the returned slice.
func (d *PartialDownloader) parseIdx(body []byte) []idxEntry {
	var entries []idxEntry
	for _, line := range strings.Split(string(body), "\n") {
		// TrimSpace also removes the '\r' left behind by CRLF line endings.
		line = strings.TrimSpace(line)
		if line == "" {
			continue
		}
		parts := strings.Split(line, ":")
		if len(parts) < 7 {
			continue
		}
		byteStart, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64)
		if err != nil || byteStart < 0 {
			continue
		}
		entries = append(entries, idxEntry{
			Index:     len(entries),
			ByteStart: byteStart,
			Date:      parts[2],
			Variable:  parts[3],
			Level:     parts[4],
			Forecast:  parts[5],
		})
	}
	return entries
}
// filterEntries keeps the entries whose Variable is one of d.Variables and
// whose Level ends in " mb" (a pressure level). Input order is preserved;
// the result is nil when nothing matches.
func (d *PartialDownloader) filterEntries(entries []idxEntry) []idxEntry {
	wanted := make(map[string]struct{}, len(d.Variables))
	for _, name := range d.Variables {
		wanted[name] = struct{}{}
	}

	var kept []idxEntry
	for _, entry := range entries {
		if !strings.HasSuffix(entry.Level, " mb") {
			continue
		}
		if _, ok := wanted[entry.Variable]; !ok {
			continue
		}
		kept = append(kept, entry)
	}
	return kept
}
// doWithRetry executes req with d.Client, retrying on transport errors and
// on 5xx responses.
//
// Up to three attempts are made, waiting 1s and then 2s between them. A
// response with a status below 500 is returned as-is (the caller owns its
// Body and must still check the status). If every attempt fails the last
// failure is returned; the error is never nil together with a nil response.
// Cancelling ctx stops the retries and returns ctx.Err().
func (d *PartialDownloader) doWithRetry(ctx context.Context, req *http.Request) (*http.Response, error) {
	const maxRetries = 3
	backoff := 1 * time.Second

	var lastErr error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		resp, err := d.Client.Do(req)
		if err == nil && resp.StatusCode < 500 {
			return resp, nil
		}
		if err == nil {
			// A 5xx reply is a failure too; record it so that exhausting the
			// retries does not yield a nil response with a nil error.
			err = fmt.Errorf("bad status: %s", resp.Status)
		}
		if resp != nil {
			resp.Body.Close()
		}
		lastErr = err

		// No point in retrying once the caller has given up.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		log.Ctx(ctx).Warn("retry download", zap.Int("attempt", attempt), zap.Error(err))

		// Do not wait after the final attempt; there is nothing left to retry.
		if attempt == maxRetries {
			break
		}
		timer := time.NewTimer(backoff)
		select {
		case <-timer.C:
			backoff *= 2
		case <-ctx.Done():
			timer.Stop()
			return nil, ctx.Err()
		}
	}
	return nil, lastErr
}
// downloadRange fetches a byte range of url and copies it to out.
//
// start is the first byte offset. A positive end is the last byte offset
// (inclusive); any other end requests everything from start to the end of
// the resource. A 206 reply is always accepted. A 200 reply means the server
// ignored the Range header and is sending the whole resource, so it is
// accepted only when the whole resource was requested (start 0, no end);
// otherwise the full body would be written where a slice was expected.
// Any other status is reported as errcodes.ErrDownload.
func (d *PartialDownloader) downloadRange(ctx context.Context, url string, start, end int64, out io.Writer) error {
	bounded := end > 0
	if start < 0 || (bounded && end < start) {
		return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("invalid byte range %d-%d", start, end))
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	rangeHeader := fmt.Sprintf("bytes=%d-", start)
	if bounded {
		rangeHeader = fmt.Sprintf("bytes=%d-%d", start, end)
	}
	req.Header.Set("Range", rangeHeader)

	resp, err := d.Client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	switch resp.StatusCode {
	case http.StatusPartialContent:
		// The server honoured the range.
	case http.StatusOK:
		if start != 0 || bounded {
			return errcodes.Wrap(errcodes.ErrDownload, "range request not honored: "+resp.Status)
		}
	default:
		return errcodes.Wrap(errcodes.ErrDownload, "bad status: "+resp.Status)
	}

	_, err = io.Copy(out, resp.Body)
	return err
}
// downloadFieldsFromURL downloads into dst only those records of the file at
// url that filterEntries selects from the index served at url+".idx".
//
// The work is skipped when dst already exists and is non-empty, and nothing
// is written when the index contains no matching record. Data is first
// written to dst+".part" and renamed into place only after every range has
// been downloaded and flushed; on failure the partial file is removed.
// step is used only to label progress log messages.
func (d *PartialDownloader) downloadFieldsFromURL(ctx context.Context, url string, dst string, step int) (err error) {
	// A non-empty destination is treated as an already completed download.
	if info, statErr := os.Stat(dst); statErr == nil && info.Size() > 0 {
		return nil
	}

	reqIdx, err := http.NewRequestWithContext(ctx, http.MethodGet, url+".idx", nil)
	if err != nil {
		return errcodes.Wrap(err, "failed to build idx request")
	}
	respIdx, err := d.doWithRetry(ctx, reqIdx)
	if err != nil {
		return errcodes.Wrap(err, "failed to get idx")
	}
	if respIdx == nil {
		return errcodes.Wrap(errcodes.ErrDownload, "failed to get idx: no response")
	}
	defer respIdx.Body.Close()
	// Without this check an error page would parse to zero entries and the
	// function would report success without producing a file.
	if respIdx.StatusCode != http.StatusOK {
		return errcodes.Wrap(errcodes.ErrDownload, "failed to get idx: bad status: "+respIdx.Status)
	}
	idxBody, err := io.ReadAll(respIdx.Body)
	if err != nil {
		return errcodes.Wrap(err, "failed to read idx")
	}

	entries := d.parseIdx(idxBody)
	filtered := d.filterEntries(entries)
	if len(filtered) == 0 {
		return nil
	}

	// endOf maps a record's start offset to its last byte: one less than the
	// next strictly greater start offset, or -1 ("to end of file") for the
	// final record. Walking backwards lets records that share a start offset
	// inherit the same end.
	// NOTE(review): shared offsets are assumed to denote sub-records of one
	// message — confirm against the index producer.
	endOf := make(map[int64]int64, len(entries))
	lastByte := int64(-1)
	for i := len(entries) - 1; i >= 0; i-- {
		if i+1 < len(entries) && entries[i+1].ByteStart > entries[i].ByteStart {
			lastByte = entries[i+1].ByteStart - 1
		}
		endOf[entries[i].ByteStart] = lastByte
	}

	type chunk struct{ start, end int64 }
	chunks := make([]chunk, 0, len(filtered))
	var totalBytes int64
	for _, entry := range filtered {
		// Consecutive records at the same offset cover the same bytes;
		// download them once.
		if n := len(chunks); n > 0 && chunks[n-1].start == entry.ByteStart {
			continue
		}
		end, ok := endOf[entry.ByteStart]
		if !ok {
			end = -1
		}
		chunks = append(chunks, chunk{start: entry.ByteStart, end: end})
		if end > 0 {
			totalBytes += end - entry.ByteStart + 1
		}
	}

	tmp := dst + ".part"
	f, err := os.OpenFile(tmp, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if err != nil {
		return err
	}

	var downloaded int64
	err = func() (werr error) {
		// Surface the Close error too: it can be the only report of a
		// failed write.
		defer func() {
			if cerr := f.Close(); werr == nil {
				werr = cerr
			}
		}()

		bufWriter := bufio.NewWriterSize(f, 1024*1024)
		for i, c := range chunks {
			countingWriter := &proxyWriter{
				Writer: bufWriter,
				OnWrite: func(n int) {
					downloaded += int64(n)
					// Log progress only for every 20th chunk to limit noise.
					if totalBytes > 0 && i%20 == 0 {
						pct := float64(downloaded) / float64(totalBytes) * 100
						log.Ctx(ctx).Debug("download progress",
							zap.Int("step", step),
							zap.String("pct", fmt.Sprintf("%.1f%%", pct)))
					}
				},
			}
			if err := d.downloadRange(ctx, url, c.start, c.end, countingWriter); err != nil {
				return err
			}
		}
		return bufWriter.Flush()
	}()
	if err != nil {
		os.Remove(tmp)
		return err
	}

	if err = d.safeRename(tmp, dst); err != nil {
		os.Remove(tmp)
		return err
	}
	return nil
}
// proxyWriter forwards writes to the embedded Writer and reports the number
// of bytes each write actually delivered.
type proxyWriter struct {
	io.Writer
	// OnWrite, when non-nil, is called with the byte count of every write
	// that delivered at least one byte.
	OnWrite func(int)
}

// Write passes data to the embedded Writer and returns its result unchanged.
// OnWrite is invoked with the written count whenever that count is positive,
// even if the underlying write also returned an error.
func (pw *proxyWriter) Write(data []byte) (written int, err error) {
	written, err = pw.Writer.Write(data)
	if pw.OnWrite == nil || written <= 0 {
		return written, err
	}
	pw.OnWrite(written)
	return written, err
}
// safeRename renames src to dst, making up to five attempts with a 150ms
// pause between consecutive attempts. It returns nil as soon as one attempt
// succeeds; otherwise it returns the last os.Rename error wrapped with
// "rename failed".
// NOTE(review): the retries presumably ride out transient failures such as
// the file being briefly held by another process — confirm.
func (d *PartialDownloader) safeRename(src, dst string) error {
	const (
		attempts = 5
		pause    = 150 * time.Millisecond
	)

	var lastErr error
	for i := 0; i < attempts; i++ {
		// Pause only between attempts, never after the final failure.
		if i > 0 {
			time.Sleep(pause)
		}
		if lastErr = os.Rename(src, dst); lastErr == nil {
			return nil
		}
	}
	return fmt.Errorf("rename failed: %w", lastErr)
}
// Run downloads every required file (pgrb2 and pgrb2b) for the model run.
//
// For each step reported by the dataset config, two downloads are started:
// the primary file (GribURL/FileName) and the secondary file
// (GribURLB/FileNameB), both written under d.Dir. At most d.Parallel
// downloads run at once; values below 1 are treated as 1, because a zero
// capacity would block forever and a negative one would panic. The first
// failing download cancels the remaining ones and its error is returned.
// If ctx is cancelled before all downloads were started, ctx.Err() is
// returned.
func (d *PartialDownloader) Run(ctx context.Context, run time.Time) error {
	log.Ctx(ctx).Info("starting partial download",
		zap.Time("run", run),
		zap.Strings("variables", d.Variables))

	parallel := d.Parallel
	if parallel < 1 {
		parallel = 1
	}

	g, gctx := errgroup.WithContext(ctx)
	sem := make(chan struct{}, parallel)

	// start waits for a free slot and launches one download. It reports
	// false, without launching anything, once the group context is done.
	start := func(url, dst string, step int) bool {
		select {
		case sem <- struct{}{}:
		case <-gctx.Done():
			return false
		}
		g.Go(func() error {
			defer func() { <-sem }()
			return d.downloadFieldsFromURL(gctx, url, dst, step)
		})
		return true
	}

	interrupted := false
	for _, step := range d.ds.Steps() {
		// Primary pgrb2 first, then secondary pgrb2b.
		if !start(d.ds.GribURL(run, step), filepath.Join(d.Dir, d.ds.FileName(run, step)), step) ||
			!start(d.ds.GribURLB(run, step), filepath.Join(d.Dir, d.ds.FileNameB(run, step)), step) {
			interrupted = true
			break
		}
	}

	err := g.Wait()
	if err == nil && interrupted {
		// No download failed, so scheduling stopped because the caller's
		// context ended. gctx is always cancelled after Wait, so consult
		// the parent.
		err = ctx.Err()
	}
	return err
}
// GetLatestModelRun finds the most recent available GFS forecast run.
//
// Starting from the current UTC time rounded down to a multiple of six
// hours, it issues a HEAD request for the file of forecast hour dc.MaxHour
// and steps back six hours at a time, for at most eight candidates. The
// first candidate answered with 200 OK is returned. If ctx is cancelled the
// search stops and ctx.Err() is returned; if no candidate is found the
// error wraps errcodes.ErrDownload.
func GetLatestModelRun(ctx context.Context, dc *DatasetConfig) (time.Time, error) {
	const (
		cycle     = 6 * time.Hour
		maxCycles = 8
	)

	now := time.Now().UTC()
	hour := now.Hour() - (now.Hour() % 6)
	current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
	client := &http.Client{Timeout: 10 * time.Second}

	for i := 0; i < maxCycles; i, current = i+1, current.Add(-cycle) {
		// A cancelled context would fail every remaining probe; report the
		// cancellation instead of a misleading "not found".
		if err := ctx.Err(); err != nil {
			return time.Time{}, err
		}

		url := dc.GribURL(current, dc.MaxHour)
		req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
		if err != nil {
			continue
		}
		resp, err := client.Do(req)
		if resp != nil {
			resp.Body.Close()
		}
		if err == nil && resp.StatusCode == http.StatusOK {
			log.Ctx(ctx).Info("found latest model run", zap.Time("run", current))
			return current, nil
		}
	}
	return time.Time{}, errcodes.Wrap(errcodes.ErrDownload, "no recent GFS forecast found")
}