276 lines
6.1 KiB
Go
276 lines
6.1 KiB
Go
// Package async implements the asynchronous prediction endpoints
// (/api/v1/predictions{,/{id}}) and the worker pool that executes them.
//
// Each enqueued request is assigned a job ID; the result is held in
// memory for a configurable TTL after completion.
|
|
package async
|
|
|
|
import (
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
|
|
"predictor-refactored/internal/api/v2"
|
|
"predictor-refactored/internal/datasets"
|
|
"predictor-refactored/internal/elevation"
|
|
"predictor-refactored/internal/metrics"
|
|
)
|
|
|
|
// Status names the lifecycle state a prediction job is currently in.
type Status string

// Lifecycle states reported for a prediction job.
const (
	// StatusPending is the state a job is created in, before a worker
	// has picked it up.
	StatusPending = Status("pending")
	// StatusRunning is set by a worker just before it executes the job.
	StatusRunning = Status("running")
	// StatusComplete means the prediction returned without an error.
	StatusComplete = Status("complete")
	// StatusFailed means the prediction returned an error or the job
	// could not be queued.
	StatusFailed = Status("failed")
	// StatusCancelled means the job was cancelled through Cancel.
	StatusCancelled = Status("cancelled")
)
|
|
|
|
// JobInfo is the externally-visible snapshot of one prediction job.
|
|
type JobInfo struct {
|
|
ID string `json:"id"`
|
|
Status Status `json:"status"`
|
|
CreatedAt time.Time `json:"created_at"`
|
|
StartedAt *time.Time `json:"started_at,omitempty"`
|
|
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
|
Error string `json:"error,omitempty"`
|
|
Result *v2.PredictionResponse `json:"result,omitempty"`
|
|
}
|
|
|
|
type job struct {
|
|
id string
|
|
req v2.PredictionRequest
|
|
createdAt time.Time
|
|
|
|
mu sync.Mutex
|
|
status Status
|
|
startedAt time.Time
|
|
completedAt time.Time
|
|
errStr string
|
|
result *v2.PredictionResponse
|
|
cancel chan struct{}
|
|
}
|
|
|
|
func (j *job) snapshot() JobInfo {
|
|
j.mu.Lock()
|
|
defer j.mu.Unlock()
|
|
info := JobInfo{
|
|
ID: j.id,
|
|
Status: j.status,
|
|
CreatedAt: j.createdAt,
|
|
Error: j.errStr,
|
|
Result: j.result,
|
|
}
|
|
if !j.startedAt.IsZero() {
|
|
t := j.startedAt
|
|
info.StartedAt = &t
|
|
}
|
|
if !j.completedAt.IsZero() {
|
|
t := j.completedAt
|
|
info.CompletedAt = &t
|
|
}
|
|
return info
|
|
}
|
|
|
|
// Manager runs a fixed pool of workers to execute prediction jobs and
|
|
// retains their results for the configured TTL.
|
|
type Manager struct {
|
|
mgr *datasets.Manager
|
|
elev *elevation.Dataset
|
|
metrics metrics.Sink
|
|
log *zap.Logger
|
|
|
|
queue chan *job
|
|
ttl time.Duration
|
|
|
|
jobsMu sync.RWMutex
|
|
jobs map[string]*job
|
|
|
|
inflight atomic.Int64
|
|
closed chan struct{}
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// Config holds the tunables used by New when building a Manager.
// Any field left at zero or set to a negative value is replaced by its
// default.
type Config struct {
	// Workers is the number of worker goroutines, i.e. the maximum number
	// of predictions executing at once. Defaults to 4.
	Workers int

	// QueueSize is the capacity of the queue of jobs waiting to start.
	// Defaults to 64.
	QueueSize int

	// ResultTTL is how long completed or failed jobs are kept in memory.
	// Defaults to one hour.
	ResultTTL time.Duration
}
|
|
|
|
// New constructs a Manager with the given config and starts the workers.
|
|
func New(cfg Config, mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Manager {
|
|
if cfg.Workers <= 0 {
|
|
cfg.Workers = 4
|
|
}
|
|
if cfg.QueueSize <= 0 {
|
|
cfg.QueueSize = 64
|
|
}
|
|
if cfg.ResultTTL <= 0 {
|
|
cfg.ResultTTL = time.Hour
|
|
}
|
|
if sink == nil {
|
|
sink = metrics.Noop()
|
|
}
|
|
if log == nil {
|
|
log = zap.NewNop()
|
|
}
|
|
m := &Manager{
|
|
mgr: mgr, elev: elev, metrics: sink, log: log,
|
|
queue: make(chan *job, cfg.QueueSize),
|
|
jobs: make(map[string]*job),
|
|
ttl: cfg.ResultTTL,
|
|
closed: make(chan struct{}),
|
|
}
|
|
for range cfg.Workers {
|
|
m.wg.Add(1)
|
|
go m.worker()
|
|
}
|
|
m.wg.Add(1)
|
|
go m.evictor()
|
|
return m
|
|
}
|
|
|
|
// Enqueue creates a new job from req and returns its snapshot.
|
|
// Returns false when the queue is full.
|
|
func (m *Manager) Enqueue(req v2.PredictionRequest) (JobInfo, bool) {
|
|
j := &job{
|
|
id: uuid.New().String(),
|
|
req: req,
|
|
createdAt: time.Now().UTC(),
|
|
status: StatusPending,
|
|
cancel: make(chan struct{}),
|
|
}
|
|
m.jobsMu.Lock()
|
|
m.jobs[j.id] = j
|
|
m.jobsMu.Unlock()
|
|
|
|
select {
|
|
case m.queue <- j:
|
|
return j.snapshot(), true
|
|
default:
|
|
// Queue full — mark the job failed and return it.
|
|
j.mu.Lock()
|
|
j.status = StatusFailed
|
|
j.errStr = "prediction queue full"
|
|
j.completedAt = time.Now().UTC()
|
|
j.mu.Unlock()
|
|
return j.snapshot(), false
|
|
}
|
|
}
|
|
|
|
// Get returns a job's snapshot.
|
|
func (m *Manager) Get(id string) (JobInfo, bool) {
|
|
m.jobsMu.RLock()
|
|
j, ok := m.jobs[id]
|
|
m.jobsMu.RUnlock()
|
|
if !ok {
|
|
return JobInfo{}, false
|
|
}
|
|
return j.snapshot(), true
|
|
}
|
|
|
|
// Cancel marks a not-yet-started job as cancelled. Returns false when the
|
|
// job is unknown or already terminal.
|
|
func (m *Manager) Cancel(id string) bool {
|
|
m.jobsMu.RLock()
|
|
j, ok := m.jobs[id]
|
|
m.jobsMu.RUnlock()
|
|
if !ok {
|
|
return false
|
|
}
|
|
j.mu.Lock()
|
|
terminal := j.status == StatusComplete || j.status == StatusFailed || j.status == StatusCancelled
|
|
if terminal {
|
|
j.mu.Unlock()
|
|
return false
|
|
}
|
|
j.status = StatusCancelled
|
|
j.completedAt = time.Now().UTC()
|
|
j.mu.Unlock()
|
|
close(j.cancel)
|
|
return true
|
|
}
|
|
|
|
// Inflight returns the count of running jobs.
|
|
func (m *Manager) Inflight() int64 { return m.inflight.Load() }
|
|
|
|
// Close shuts down workers and the evictor.
|
|
func (m *Manager) Close() {
|
|
close(m.closed)
|
|
close(m.queue)
|
|
m.wg.Wait()
|
|
}
|
|
|
|
func (m *Manager) worker() {
|
|
defer m.wg.Done()
|
|
for j := range m.queue {
|
|
// Check cancellation before starting.
|
|
j.mu.Lock()
|
|
cancelled := j.status == StatusCancelled
|
|
j.mu.Unlock()
|
|
if cancelled {
|
|
continue
|
|
}
|
|
m.inflight.Add(1)
|
|
j.mu.Lock()
|
|
j.status = StatusRunning
|
|
j.startedAt = time.Now().UTC()
|
|
j.mu.Unlock()
|
|
|
|
resp, err := v2.Run(m.mgr, m.elev, j.req)
|
|
|
|
j.mu.Lock()
|
|
j.completedAt = time.Now().UTC()
|
|
if err != nil {
|
|
j.status = StatusFailed
|
|
j.errStr = err.Error()
|
|
} else {
|
|
j.status = StatusComplete
|
|
j.result = resp
|
|
}
|
|
j.mu.Unlock()
|
|
m.inflight.Add(-1)
|
|
|
|
if err == nil {
|
|
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), nil)
|
|
} else {
|
|
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) evictor() {
|
|
defer m.wg.Done()
|
|
ticker := time.NewTicker(m.ttl / 4)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-m.closed:
|
|
return
|
|
case <-ticker.C:
|
|
m.evictExpired()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) evictExpired() {
|
|
now := time.Now().UTC()
|
|
m.jobsMu.Lock()
|
|
defer m.jobsMu.Unlock()
|
|
for id, j := range m.jobs {
|
|
j.mu.Lock()
|
|
expired := !j.completedAt.IsZero() && now.Sub(j.completedAt) > m.ttl
|
|
j.mu.Unlock()
|
|
if expired {
|
|
delete(m.jobs, id)
|
|
}
|
|
}
|
|
}
|