engine refactor
This commit is contained in:
parent
9e663db9dc
commit
81b8e763bd
37 changed files with 3532 additions and 1639 deletions
63
internal/api/async/handler.go
Normal file
63
internal/api/async/handler.go
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
package async
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"predictor-refactored/internal/api/httpjson"
|
||||
"predictor-refactored/internal/api/v2"
|
||||
)
|
||||
|
||||
// Handler implements the /api/v1/predictions{,/{id}} endpoints.
|
||||
type Handler struct {
|
||||
mgr *Manager
|
||||
}
|
||||
|
||||
// NewHandler wires a handler.
|
||||
func NewHandler(mgr *Manager) *Handler { return &Handler{mgr: mgr} }
|
||||
|
||||
// Register installs the async routes on mux.
|
||||
func (h *Handler) Register(mux *http.ServeMux) {
|
||||
mux.HandleFunc("POST /api/v1/predictions", h.create)
|
||||
mux.HandleFunc("GET /api/v1/predictions/{id}", h.get)
|
||||
mux.HandleFunc("DELETE /api/v1/predictions/{id}", h.cancel)
|
||||
}
|
||||
|
||||
func (h *Handler) create(w http.ResponseWriter, r *http.Request) {
|
||||
var req v2.PredictionRequest
|
||||
dec := json.NewDecoder(r.Body)
|
||||
dec.DisallowUnknownFields()
|
||||
if err := dec.Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid body: "+err.Error())
|
||||
return
|
||||
}
|
||||
info, accepted := h.mgr.Enqueue(req)
|
||||
if !accepted {
|
||||
writeJSON(w, http.StatusServiceUnavailable, info)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Location", "/api/v1/predictions/"+info.ID)
|
||||
writeJSON(w, http.StatusAccepted, info)
|
||||
}
|
||||
|
||||
func (h *Handler) get(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
info, ok := h.mgr.Get(id)
|
||||
if !ok {
|
||||
writeError(w, http.StatusNotFound, "prediction job not found")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, info)
|
||||
}
|
||||
|
||||
func (h *Handler) cancel(w http.ResponseWriter, r *http.Request) {
|
||||
id := r.PathValue("id")
|
||||
if !h.mgr.Cancel(id) {
|
||||
writeError(w, http.StatusConflict, "job not found or already terminal")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
var writeJSON = httpjson.Write
|
||||
var writeError = httpjson.Error
|
||||
276
internal/api/async/manager.go
Normal file
276
internal/api/async/manager.go
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
// Package async implements the asynchronous prediction endpoints
|
||||
// (/api/v1/predictions{,/{id}}) and the worker pool that executes them.
|
||||
//
|
||||
// Each enqueued request is assigned a job ID; the result is held in
|
||||
// memory for a configurable TTL after completion.
|
||||
package async
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"predictor-refactored/internal/api/v2"
|
||||
"predictor-refactored/internal/datasets"
|
||||
"predictor-refactored/internal/elevation"
|
||||
"predictor-refactored/internal/metrics"
|
||||
)
|
||||
|
||||
// Status is the lifecycle state of a prediction job.
|
||||
type Status string
|
||||
|
||||
const (
|
||||
StatusPending Status = "pending"
|
||||
StatusRunning Status = "running"
|
||||
StatusComplete Status = "complete"
|
||||
StatusFailed Status = "failed"
|
||||
StatusCancelled Status = "cancelled"
|
||||
)
|
||||
|
||||
// JobInfo is the externally-visible snapshot of one prediction job.
|
||||
type JobInfo struct {
|
||||
ID string `json:"id"`
|
||||
Status Status `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
StartedAt *time.Time `json:"started_at,omitempty"`
|
||||
CompletedAt *time.Time `json:"completed_at,omitempty"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Result *v2.PredictionResponse `json:"result,omitempty"`
|
||||
}
|
||||
|
||||
type job struct {
|
||||
id string
|
||||
req v2.PredictionRequest
|
||||
createdAt time.Time
|
||||
|
||||
mu sync.Mutex
|
||||
status Status
|
||||
startedAt time.Time
|
||||
completedAt time.Time
|
||||
errStr string
|
||||
result *v2.PredictionResponse
|
||||
cancel chan struct{}
|
||||
}
|
||||
|
||||
func (j *job) snapshot() JobInfo {
|
||||
j.mu.Lock()
|
||||
defer j.mu.Unlock()
|
||||
info := JobInfo{
|
||||
ID: j.id,
|
||||
Status: j.status,
|
||||
CreatedAt: j.createdAt,
|
||||
Error: j.errStr,
|
||||
Result: j.result,
|
||||
}
|
||||
if !j.startedAt.IsZero() {
|
||||
t := j.startedAt
|
||||
info.StartedAt = &t
|
||||
}
|
||||
if !j.completedAt.IsZero() {
|
||||
t := j.completedAt
|
||||
info.CompletedAt = &t
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
// Manager runs a fixed pool of workers to execute prediction jobs and
|
||||
// retains their results for the configured TTL.
|
||||
type Manager struct {
|
||||
mgr *datasets.Manager
|
||||
elev *elevation.Dataset
|
||||
metrics metrics.Sink
|
||||
log *zap.Logger
|
||||
|
||||
queue chan *job
|
||||
ttl time.Duration
|
||||
|
||||
jobsMu sync.RWMutex
|
||||
jobs map[string]*job
|
||||
|
||||
inflight atomic.Int64
|
||||
closed chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// Config controls Manager construction.
|
||||
type Config struct {
|
||||
// Workers is the maximum concurrent prediction executions.
|
||||
Workers int
|
||||
// QueueSize bounds the number of jobs waiting to start.
|
||||
QueueSize int
|
||||
// ResultTTL is how long completed/failed jobs are retained in memory.
|
||||
ResultTTL time.Duration
|
||||
}
|
||||
|
||||
// New constructs a Manager with the given config and starts the workers.
|
||||
func New(cfg Config, mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Manager {
|
||||
if cfg.Workers <= 0 {
|
||||
cfg.Workers = 4
|
||||
}
|
||||
if cfg.QueueSize <= 0 {
|
||||
cfg.QueueSize = 64
|
||||
}
|
||||
if cfg.ResultTTL <= 0 {
|
||||
cfg.ResultTTL = time.Hour
|
||||
}
|
||||
if sink == nil {
|
||||
sink = metrics.Noop()
|
||||
}
|
||||
if log == nil {
|
||||
log = zap.NewNop()
|
||||
}
|
||||
m := &Manager{
|
||||
mgr: mgr, elev: elev, metrics: sink, log: log,
|
||||
queue: make(chan *job, cfg.QueueSize),
|
||||
jobs: make(map[string]*job),
|
||||
ttl: cfg.ResultTTL,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
for range cfg.Workers {
|
||||
m.wg.Add(1)
|
||||
go m.worker()
|
||||
}
|
||||
m.wg.Add(1)
|
||||
go m.evictor()
|
||||
return m
|
||||
}
|
||||
|
||||
// Enqueue creates a new job from req and returns its snapshot.
|
||||
// Returns false when the queue is full.
|
||||
func (m *Manager) Enqueue(req v2.PredictionRequest) (JobInfo, bool) {
|
||||
j := &job{
|
||||
id: uuid.New().String(),
|
||||
req: req,
|
||||
createdAt: time.Now().UTC(),
|
||||
status: StatusPending,
|
||||
cancel: make(chan struct{}),
|
||||
}
|
||||
m.jobsMu.Lock()
|
||||
m.jobs[j.id] = j
|
||||
m.jobsMu.Unlock()
|
||||
|
||||
select {
|
||||
case m.queue <- j:
|
||||
return j.snapshot(), true
|
||||
default:
|
||||
// Queue full — mark the job failed and return it.
|
||||
j.mu.Lock()
|
||||
j.status = StatusFailed
|
||||
j.errStr = "prediction queue full"
|
||||
j.completedAt = time.Now().UTC()
|
||||
j.mu.Unlock()
|
||||
return j.snapshot(), false
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a job's snapshot.
|
||||
func (m *Manager) Get(id string) (JobInfo, bool) {
|
||||
m.jobsMu.RLock()
|
||||
j, ok := m.jobs[id]
|
||||
m.jobsMu.RUnlock()
|
||||
if !ok {
|
||||
return JobInfo{}, false
|
||||
}
|
||||
return j.snapshot(), true
|
||||
}
|
||||
|
||||
// Cancel marks a not-yet-started job as cancelled. Returns false when the
|
||||
// job is unknown or already terminal.
|
||||
func (m *Manager) Cancel(id string) bool {
|
||||
m.jobsMu.RLock()
|
||||
j, ok := m.jobs[id]
|
||||
m.jobsMu.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
j.mu.Lock()
|
||||
terminal := j.status == StatusComplete || j.status == StatusFailed || j.status == StatusCancelled
|
||||
if terminal {
|
||||
j.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
j.status = StatusCancelled
|
||||
j.completedAt = time.Now().UTC()
|
||||
j.mu.Unlock()
|
||||
close(j.cancel)
|
||||
return true
|
||||
}
|
||||
|
||||
// Inflight returns the count of running jobs.
|
||||
func (m *Manager) Inflight() int64 { return m.inflight.Load() }
|
||||
|
||||
// Close shuts down workers and the evictor.
|
||||
func (m *Manager) Close() {
|
||||
close(m.closed)
|
||||
close(m.queue)
|
||||
m.wg.Wait()
|
||||
}
|
||||
|
||||
func (m *Manager) worker() {
|
||||
defer m.wg.Done()
|
||||
for j := range m.queue {
|
||||
// Check cancellation before starting.
|
||||
j.mu.Lock()
|
||||
cancelled := j.status == StatusCancelled
|
||||
j.mu.Unlock()
|
||||
if cancelled {
|
||||
continue
|
||||
}
|
||||
m.inflight.Add(1)
|
||||
j.mu.Lock()
|
||||
j.status = StatusRunning
|
||||
j.startedAt = time.Now().UTC()
|
||||
j.mu.Unlock()
|
||||
|
||||
resp, err := v2.Run(m.mgr, m.elev, j.req)
|
||||
|
||||
j.mu.Lock()
|
||||
j.completedAt = time.Now().UTC()
|
||||
if err != nil {
|
||||
j.status = StatusFailed
|
||||
j.errStr = err.Error()
|
||||
} else {
|
||||
j.status = StatusComplete
|
||||
j.result = resp
|
||||
}
|
||||
j.mu.Unlock()
|
||||
m.inflight.Add(-1)
|
||||
|
||||
if err == nil {
|
||||
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), nil)
|
||||
} else {
|
||||
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) evictor() {
|
||||
defer m.wg.Done()
|
||||
ticker := time.NewTicker(m.ttl / 4)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-m.closed:
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.evictExpired()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) evictExpired() {
|
||||
now := time.Now().UTC()
|
||||
m.jobsMu.Lock()
|
||||
defer m.jobsMu.Unlock()
|
||||
for id, j := range m.jobs {
|
||||
j.mu.Lock()
|
||||
expired := !j.completedAt.IsZero() && now.Sub(j.completedAt) > m.ttl
|
||||
j.mu.Unlock()
|
||||
if expired {
|
||||
delete(m.jobs, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue