// Listing metadata: predictor/internal/api/async/manager.go — 279 lines, 6.2 KiB, Go.

// Package async runs profile-driven predictions on a bounded worker pool and
// retains their results in memory for a configurable TTL. It is the engine
// behind the asynchronous prediction endpoints; the HTTP surface itself is
// the ogen-generated server in the parent package.
//
// The package is decoupled from the request/response wire types: a RunFunc is
// injected at construction, so this file imports only the generated API types
// it stores and returns.
package async
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"predictor-refactored/internal/metrics"
apirest "predictor-refactored/pkg/rest"
)
// RunFunc executes one prediction synchronously, returning either the
// response or an error. The Manager invokes it from its worker goroutines, so
// with more than one worker it may be called concurrently. A panic raised by
// a RunFunc is recovered by the Manager and recorded as the job's failure.
type RunFunc func(req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error)
// Status names the lifecycle state of a prediction job. Its string value is
// what callers see in a JobInfo snapshot.
type Status string

// The states a job moves through. A job starts pending; complete, failed and
// cancelled are terminal.
const (
	// StatusPending marks a job that has been created but not yet picked up
	// by a worker.
	StatusPending = Status("pending")
	// StatusRunning marks a job a worker is currently executing.
	StatusRunning = Status("running")
	// StatusComplete marks a job whose prediction returned without error.
	StatusComplete = Status("complete")
	// StatusFailed marks a job that could not be queued or whose prediction
	// returned an error or panicked.
	StatusFailed = Status("failed")
	// StatusCancelled marks a job that was cancelled while still pending.
	StatusCancelled = Status("cancelled")
)
// JobInfo is a point-in-time, caller-facing copy of one prediction job's
// state.
type JobInfo struct {
	// ID is the job's unique identifier.
	ID string
	// Status is the lifecycle state at the moment the snapshot was taken.
	Status Status
	// CreatedAt is when the job was created.
	CreatedAt time.Time
	// StartedAt is nil until a worker has started the job.
	StartedAt *time.Time
	// CompletedAt is nil until the job has reached a terminal state.
	CompletedAt *time.Time
	// Error holds the failure reason; it is empty unless the job failed.
	Error string
	// Result is the prediction output; it is nil unless the job completed.
	Result *apirest.PredictionV2Response
}
// job is the Manager's internal, mutable record of one prediction. id, req
// and createdAt are set when the job is created; the fields declared after mu
// change over the job's lifetime and are guarded by it.
type job struct {
	id        string
	req       *apirest.PredictionV2Request
	createdAt time.Time

	// mu guards every field below.
	mu sync.Mutex
	// status is the current lifecycle state.
	status Status
	// startedAt is zero until a worker picks the job up.
	startedAt time.Time
	// completedAt is zero until the job reaches a terminal state.
	completedAt time.Time
	// errStr is the failure reason, if any.
	errStr string
	// result is the prediction output of a completed job.
	result *apirest.PredictionV2Response
}
// snapshot copies the job's current state into a JobInfo while holding the
// job lock. Timestamps that are still zero are reported as nil pointers; set
// timestamps are copied so the caller never aliases the job's own fields.
func (j *job) snapshot() JobInfo {
	j.mu.Lock()
	defer j.mu.Unlock()

	out := JobInfo{
		ID:        j.id,
		Status:    j.status,
		CreatedAt: j.createdAt,
		Error:     j.errStr,
		Result:    j.result,
	}
	if started := j.startedAt; !started.IsZero() {
		out.StartedAt = &started
	}
	if completed := j.completedAt; !completed.IsZero() {
		out.CompletedAt = &completed
	}
	return out
}
// Manager executes prediction jobs on a fixed pool of worker goroutines and
// keeps each job's outcome in memory until its TTL expires.
type Manager struct {
	// run executes a single prediction.
	run RunFunc
	// metrics receives one Prediction observation per executed job.
	metrics metrics.Sink
	log     *zap.Logger

	// queue holds pending jobs for the workers to consume.
	queue chan *job
	// ttl is how long a terminal job is retained before eviction.
	ttl time.Duration

	// jobsMu guards jobs.
	jobsMu sync.RWMutex
	// jobs indexes every retained job by its ID.
	jobs map[string]*job

	// inflight counts the jobs currently being executed.
	inflight atomic.Int64
	// closed is closed by Close to stop the evictor.
	closed chan struct{}
	// wg tracks the worker goroutines and the evictor.
	wg sync.WaitGroup
}
// Config holds the tunables used when constructing a Manager. New replaces
// any field that is zero or negative with a default.
type Config struct {
	// Workers is the maximum number of predictions executing concurrently.
	Workers int
	// QueueSize bounds the number of jobs waiting for a worker.
	QueueSize int
	// ResultTTL is how long a terminal job is retained.
	ResultTTL time.Duration
}
// New builds a Manager and immediately starts its worker goroutines plus the
// evictor. run executes one prediction. A nil sink is replaced by a no-op
// sink and a nil log by a no-op logger. Non-positive Config fields fall back
// to 4 workers, a queue of 64 and a one-hour result TTL.
func New(cfg Config, run RunFunc, sink metrics.Sink, log *zap.Logger) *Manager {
	const (
		defaultWorkers   = 4
		defaultQueueSize = 64
		defaultResultTTL = time.Hour
	)

	workers, queueSize, ttl := cfg.Workers, cfg.QueueSize, cfg.ResultTTL
	if workers <= 0 {
		workers = defaultWorkers
	}
	if queueSize <= 0 {
		queueSize = defaultQueueSize
	}
	if ttl <= 0 {
		ttl = defaultResultTTL
	}
	if sink == nil {
		sink = metrics.Noop()
	}
	if log == nil {
		log = zap.NewNop()
	}

	m := &Manager{
		run:     run,
		metrics: sink,
		log:     log,
		queue:   make(chan *job, queueSize),
		jobs:    make(map[string]*job),
		ttl:     ttl,
		closed:  make(chan struct{}),
	}

	// Register every goroutine before starting any of them: one per worker
	// plus one for the evictor.
	m.wg.Add(workers + 1)
	for i := 0; i < workers; i++ {
		go m.worker()
	}
	go m.evictor()
	return m
}
// Enqueue creates a job from req and returns its snapshot. The bool reports
// whether the job was queued: it is false when the queue is full or the
// Manager has been closed, and the returned snapshot is then marked failed.
//
// A job rejected because the queue is full is still recorded, so Get can
// report the failure until the job is evicted. A job rejected after Close is
// not recorded: the evictor has stopped, so nothing would ever remove it.
func (m *Manager) Enqueue(req *apirest.PredictionV2Request) (JobInfo, bool) {
	j := &job{
		id:        uuid.New().String(),
		req:       req,
		createdAt: time.Now().UTC(),
		status:    StatusPending,
	}

	// jobsMu is held across the closed check and the send. Close closes the
	// queue under the same lock, so the send below can never hit a closed
	// channel (which would panic). The send is non-blocking, so the lock is
	// only held briefly.
	m.jobsMu.Lock()
	select {
	case <-m.closed:
		m.jobsMu.Unlock()
		j.markRejected("prediction manager closed")
		return j.snapshot(), false
	default:
	}

	m.jobs[j.id] = j
	select {
	case m.queue <- j:
		m.jobsMu.Unlock()
		return j.snapshot(), true
	default:
	}

	// Queue full. The job is marked failed before jobsMu is released, so no
	// Get or Cancel can observe it as pending.
	j.markRejected("prediction queue full")
	m.jobsMu.Unlock()
	return j.snapshot(), false
}

// markRejected marks a job that never reached a worker as failed with reason.
func (j *job) markRejected(reason string) {
	j.mu.Lock()
	defer j.mu.Unlock()
	j.status = StatusFailed
	j.errStr = reason
	j.completedAt = time.Now().UTC()
}

// Get returns a snapshot of the job with the given id. The bool is false when
// no such job is retained.
func (m *Manager) Get(id string) (JobInfo, bool) {
	m.jobsMu.RLock()
	j, ok := m.jobs[id]
	m.jobsMu.RUnlock()
	if !ok {
		return JobInfo{}, false
	}
	return j.snapshot(), true
}

// Cancel marks a still-queued job cancelled. Returns false when the job is
// unknown or already running/terminal — a running prediction cannot be
// interrupted (the worker would otherwise overwrite the cancelled status with
// its result), so callers get an honest "too late" rather than a 204 that the
// worker silently undoes.
func (m *Manager) Cancel(id string) bool {
	m.jobsMu.RLock()
	j, ok := m.jobs[id]
	m.jobsMu.RUnlock()
	if !ok {
		return false
	}

	j.mu.Lock()
	defer j.mu.Unlock()
	if j.status != StatusPending {
		return false
	}
	j.status = StatusCancelled
	j.completedAt = time.Now().UTC()
	return true
}

// Inflight returns the number of jobs currently being executed.
func (m *Manager) Inflight() int64 { return m.inflight.Load() }

// Close stops the workers and the evictor and blocks until they have exited.
// Jobs already in the queue are still processed by the workers before they
// exit. Close is idempotent and safe to call concurrently with itself and
// with Enqueue; once it has begun, Enqueue rejects new jobs.
func (m *Manager) Close() {
	// Close both channels under jobsMu so that Enqueue, which checks closed
	// and sends under the same lock, never sends on a closed queue, and so
	// that a repeated Close does not close a channel twice.
	m.jobsMu.Lock()
	select {
	case <-m.closed:
		// Already closed by an earlier call; just wait below.
	default:
		close(m.closed)
		close(m.queue)
	}
	m.jobsMu.Unlock()

	// Wait outside the lock: the evictor needs jobsMu to finish a sweep.
	m.wg.Wait()
}
// worker consumes jobs from the queue until the queue is closed and drained.
// A job that was cancelled while waiting is skipped; any other job is marked
// running, stamped with its start time and handed to execute.
func (m *Manager) worker() {
	defer m.wg.Done()

	for next := range m.queue {
		next.mu.Lock()
		if next.status == StatusCancelled {
			next.mu.Unlock()
			continue
		}
		// Claim the job under its lock so a concurrent Cancel sees it as
		// running and backs off.
		next.status = StatusRunning
		next.startedAt = time.Now().UTC()
		next.mu.Unlock()

		m.execute(next)
	}
}
// execute runs one job, recovering from a panic in the injected RunFunc so a
// single bad prediction can't leak the inflight counter or kill the worker.
// The outcome (result or error string) and the completion time are recorded
// on the job, and one Prediction observation is reported to the metrics sink.
func (m *Manager) execute(j *job) {
	m.inflight.Add(1)
	defer m.inflight.Add(-1)

	// Measure the run with time.Now/time.Since rather than from the job's
	// stored timestamps: those go through UTC(), which strips the monotonic
	// clock reading, so subtracting them would follow wall-clock adjustments
	// and could even come out negative.
	start := time.Now()
	resp, err := func() (resp *apirest.PredictionV2Response, err error) {
		defer func() {
			if r := recover(); r != nil {
				// Convert the panic into an ordinary failure for this job.
				err = fmt.Errorf("prediction panicked: %v", r)
			}
		}()
		return m.run(j.req)
	}()
	elapsed := time.Since(start)

	j.mu.Lock()
	j.completedAt = time.Now().UTC()
	if err != nil {
		j.status = StatusFailed
		j.errStr = err.Error()
	} else {
		j.status = StatusComplete
		j.result = resp
	}
	j.mu.Unlock()

	// Report outside the job lock so the sink cannot hold up snapshot readers.
	m.metrics.Prediction("async", elapsed, err)
}
// evictor periodically removes expired jobs until the Manager is closed. It
// sweeps every quarter of the result TTL, but never more often than once per
// millisecond.
func (m *Manager) evictor() {
	defer m.wg.Done()

	// time.NewTicker panics on a non-positive interval, which ttl/4 yields
	// for any TTL below 4ns (e.g. a bare integer mistakenly used as a
	// duration). Clamp to a small positive floor instead of crashing.
	const minSweepInterval = time.Millisecond
	ticker := time.NewTicker(max(m.ttl/4, minSweepInterval))
	defer ticker.Stop()

	for {
		select {
		case <-m.closed:
			return
		case <-ticker.C:
			m.evictExpired()
		}
	}
}
// evictExpired deletes every retained job whose completion time is older than
// the result TTL. Jobs that have not reached a terminal state (zero
// completion time) are always kept.
func (m *Manager) evictExpired() {
	now := time.Now().UTC()

	// isExpired reads the job's completion time under the job lock.
	isExpired := func(candidate *job) bool {
		candidate.mu.Lock()
		defer candidate.mu.Unlock()
		if candidate.completedAt.IsZero() {
			return false
		}
		return now.Sub(candidate.completedAt) > m.ttl
	}

	m.jobsMu.Lock()
	defer m.jobsMu.Unlock()
	for id, candidate := range m.jobs {
		if isExpired(candidate) {
			delete(m.jobs, id)
		}
	}
}