279 lines
6.2 KiB
Go
279 lines
6.2 KiB
Go
// Package async runs profile-driven predictions on a bounded worker pool and
// retains their results in memory for a configurable TTL. It is the engine
// behind the asynchronous prediction endpoints; the HTTP surface itself is
// the ogen-generated server in the parent package.
//
// The package is decoupled from the request/response wire types: a RunFunc is
// injected at construction, so this file imports only the generated API types
// it stores and returns.
|
|
package async
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"go.uber.org/zap"
|
|
|
|
"predictor-refactored/internal/metrics"
|
|
apirest "predictor-refactored/pkg/rest"
|
|
)
|
|
|
|
// RunFunc executes one prediction synchronously.
|
|
type RunFunc func(req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error)
|
|
|
|
// Status names the lifecycle state a prediction job is currently in.
type Status string

// Lifecycle states of a job.
const (
	// StatusPending is the state a job is created in by Enqueue.
	StatusPending = Status("pending")
	// StatusRunning is set by a worker when it picks the job off the queue.
	StatusRunning = Status("running")
	// StatusComplete is set when the RunFunc returned without an error.
	StatusComplete = Status("complete")
	// StatusFailed is set when the RunFunc failed or the queue was full.
	StatusFailed = Status("failed")
	// StatusCancelled is set by Cancel on a job that was still pending.
	StatusCancelled = Status("cancelled")
)
|
|
|
|
// JobInfo is a snapshot of one prediction job.
|
|
type JobInfo struct {
|
|
ID string
|
|
Status Status
|
|
CreatedAt time.Time
|
|
StartedAt *time.Time
|
|
CompletedAt *time.Time
|
|
Error string
|
|
Result *apirest.PredictionV2Response
|
|
}
|
|
|
|
type job struct {
|
|
id string
|
|
req *apirest.PredictionV2Request
|
|
createdAt time.Time
|
|
|
|
mu sync.Mutex
|
|
status Status
|
|
startedAt time.Time
|
|
completedAt time.Time
|
|
errStr string
|
|
result *apirest.PredictionV2Response
|
|
}
|
|
|
|
func (j *job) snapshot() JobInfo {
|
|
j.mu.Lock()
|
|
defer j.mu.Unlock()
|
|
info := JobInfo{
|
|
ID: j.id, Status: j.status, CreatedAt: j.createdAt,
|
|
Error: j.errStr, Result: j.result,
|
|
}
|
|
if !j.startedAt.IsZero() {
|
|
t := j.startedAt
|
|
info.StartedAt = &t
|
|
}
|
|
if !j.completedAt.IsZero() {
|
|
t := j.completedAt
|
|
info.CompletedAt = &t
|
|
}
|
|
return info
|
|
}
|
|
|
|
// Manager runs a fixed pool of workers and retains job results for a TTL.
|
|
type Manager struct {
|
|
run RunFunc
|
|
metrics metrics.Sink
|
|
log *zap.Logger
|
|
|
|
queue chan *job
|
|
ttl time.Duration
|
|
|
|
jobsMu sync.RWMutex
|
|
jobs map[string]*job
|
|
|
|
inflight atomic.Int64
|
|
closed chan struct{}
|
|
wg sync.WaitGroup
|
|
}
|
|
|
|
// Config carries the tunables New uses to build a Manager. Any field that is
// zero or negative is replaced by New with its default.
type Config struct {
	// Workers is the maximum number of predictions executing concurrently
	// (default 4).
	Workers int
	// QueueSize bounds how many jobs may wait for a worker (default 64).
	QueueSize int
	// ResultTTL is how long a terminal job is retained (default one hour).
	ResultTTL time.Duration
}
|
|
|
|
// New constructs a Manager and starts its workers. run executes one
|
|
// prediction; sink and log may be nil.
|
|
func New(cfg Config, run RunFunc, sink metrics.Sink, log *zap.Logger) *Manager {
|
|
if cfg.Workers <= 0 {
|
|
cfg.Workers = 4
|
|
}
|
|
if cfg.QueueSize <= 0 {
|
|
cfg.QueueSize = 64
|
|
}
|
|
if cfg.ResultTTL <= 0 {
|
|
cfg.ResultTTL = time.Hour
|
|
}
|
|
if sink == nil {
|
|
sink = metrics.Noop()
|
|
}
|
|
if log == nil {
|
|
log = zap.NewNop()
|
|
}
|
|
m := &Manager{
|
|
run: run, metrics: sink, log: log,
|
|
queue: make(chan *job, cfg.QueueSize),
|
|
jobs: make(map[string]*job),
|
|
ttl: cfg.ResultTTL,
|
|
closed: make(chan struct{}),
|
|
}
|
|
for range cfg.Workers {
|
|
m.wg.Add(1)
|
|
go m.worker()
|
|
}
|
|
m.wg.Add(1)
|
|
go m.evictor()
|
|
return m
|
|
}
|
|
|
|
// Enqueue creates a job from req and returns its snapshot. The bool is false
|
|
// when the queue is full (the returned job is marked failed).
|
|
func (m *Manager) Enqueue(req *apirest.PredictionV2Request) (JobInfo, bool) {
|
|
j := &job{
|
|
id: uuid.New().String(),
|
|
req: req,
|
|
createdAt: time.Now().UTC(),
|
|
status: StatusPending,
|
|
}
|
|
m.jobsMu.Lock()
|
|
m.jobs[j.id] = j
|
|
m.jobsMu.Unlock()
|
|
|
|
select {
|
|
case m.queue <- j:
|
|
return j.snapshot(), true
|
|
default:
|
|
j.mu.Lock()
|
|
j.status = StatusFailed
|
|
j.errStr = "prediction queue full"
|
|
j.completedAt = time.Now().UTC()
|
|
j.mu.Unlock()
|
|
return j.snapshot(), false
|
|
}
|
|
}
|
|
|
|
// Get returns a job's snapshot.
|
|
func (m *Manager) Get(id string) (JobInfo, bool) {
|
|
m.jobsMu.RLock()
|
|
j, ok := m.jobs[id]
|
|
m.jobsMu.RUnlock()
|
|
if !ok {
|
|
return JobInfo{}, false
|
|
}
|
|
return j.snapshot(), true
|
|
}
|
|
|
|
// Cancel marks a still-queued job cancelled. Returns false when the job is
|
|
// unknown or already running/terminal — a running prediction cannot be
|
|
// interrupted (the worker would otherwise overwrite the cancelled status with
|
|
// its result), so callers get an honest "too late" rather than a 204 that the
|
|
// worker silently undoes.
|
|
func (m *Manager) Cancel(id string) bool {
|
|
m.jobsMu.RLock()
|
|
j, ok := m.jobs[id]
|
|
m.jobsMu.RUnlock()
|
|
if !ok {
|
|
return false
|
|
}
|
|
j.mu.Lock()
|
|
defer j.mu.Unlock()
|
|
if j.status != StatusPending {
|
|
return false
|
|
}
|
|
j.status = StatusCancelled
|
|
j.completedAt = time.Now().UTC()
|
|
return true
|
|
}
|
|
|
|
// Inflight returns the number of running jobs.
|
|
func (m *Manager) Inflight() int64 { return m.inflight.Load() }
|
|
|
|
// Close stops the workers and the evictor.
|
|
func (m *Manager) Close() {
|
|
close(m.closed)
|
|
close(m.queue)
|
|
m.wg.Wait()
|
|
}
|
|
|
|
func (m *Manager) worker() {
|
|
defer m.wg.Done()
|
|
for j := range m.queue {
|
|
j.mu.Lock()
|
|
cancelled := j.status == StatusCancelled
|
|
if !cancelled {
|
|
j.status = StatusRunning
|
|
j.startedAt = time.Now().UTC()
|
|
}
|
|
j.mu.Unlock()
|
|
if cancelled {
|
|
continue
|
|
}
|
|
m.execute(j)
|
|
}
|
|
}
|
|
|
|
// execute runs one job, recovering from a panic in the injected RunFunc so a
|
|
// single bad prediction can't leak the inflight counter or kill the worker.
|
|
func (m *Manager) execute(j *job) {
|
|
m.inflight.Add(1)
|
|
defer m.inflight.Add(-1)
|
|
|
|
resp, err := func() (resp *apirest.PredictionV2Response, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
err = fmt.Errorf("prediction panicked: %v", r)
|
|
}
|
|
}()
|
|
return m.run(j.req)
|
|
}()
|
|
|
|
j.mu.Lock()
|
|
j.completedAt = time.Now().UTC()
|
|
if err != nil {
|
|
j.status = StatusFailed
|
|
j.errStr = err.Error()
|
|
} else {
|
|
j.status = StatusComplete
|
|
j.result = resp
|
|
}
|
|
dur := j.completedAt.Sub(j.startedAt)
|
|
j.mu.Unlock()
|
|
m.metrics.Prediction("async", dur, err)
|
|
}
|
|
|
|
func (m *Manager) evictor() {
|
|
defer m.wg.Done()
|
|
ticker := time.NewTicker(m.ttl / 4)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-m.closed:
|
|
return
|
|
case <-ticker.C:
|
|
m.evictExpired()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) evictExpired() {
|
|
now := time.Now().UTC()
|
|
m.jobsMu.Lock()
|
|
defer m.jobsMu.Unlock()
|
|
for id, j := range m.jobs {
|
|
j.mu.Lock()
|
|
expired := !j.completedAt.IsZero() && now.Sub(j.completedAt) > m.ttl
|
|
j.mu.Unlock()
|
|
if expired {
|
|
delete(m.jobs, id)
|
|
}
|
|
}
|
|
}
|