predictor/internal/datasets/throttle.go
2026-05-18 03:17:17 +09:00

63 lines
1.4 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package datasets
import (
"context"
"sync"
"time"
)
// TokenBucket is a simple bytes-per-second rate limiter.
//
// A bucket created by NewTokenBucket starts full, with a capacity equal to one
// second's worth of tokens (capacity = rate * 1s). Calls to Wait block until
// enough tokens have accumulated. A TokenBucket whose rate is not positive
// (including the zero value) never throttles.
//
// A TokenBucket holds a mutex and must not be copied after first use.
type TokenBucket struct {
	mu     sync.Mutex
	rate   float64   // refill speed in tokens per second; <= 0 disables throttling
	tokens float64   // tokens currently available; guarded by mu
	cap    float64   // maximum number of tokens the bucket can hold
	last   time.Time // instant of the most recent refill; guarded by mu
}

// NewTokenBucket returns a TokenBucket emitting at most bytesPerSecond.
// A non-positive rate disables throttling (Wait becomes a no-op).
func NewTokenBucket(bytesPerSecond int64) *TokenBucket {
	if bytesPerSecond <= 0 {
		return &TokenBucket{}
	}
	r := float64(bytesPerSecond)
	return &TokenBucket{rate: r, tokens: r, cap: r, last: time.Now()}
}

// Wait blocks until n tokens have been taken from the bucket, or until ctx is
// cancelled, in which case it returns ctx.Err().
//
// A request with n <= 0 returns nil immediately and leaves the bucket
// untouched. A request larger than the bucket capacity is consumed in
// capacity-sized pieces, so it completes after roughly n/rate seconds instead
// of blocking forever. Tokens that were already taken are not handed back when
// ctx is cancelled part-way through a request.
func (t *TokenBucket) Wait(ctx context.Context, n int) error {
	if t.rate <= 0 || n <= 0 {
		return nil
	}
	remaining := float64(n)
	for remaining > 0 {
		// The balance is capped at t.cap, so a single take must never ask
		// for more than that or it could not be satisfied.
		chunk := remaining
		if t.cap > 0 && chunk > t.cap {
			chunk = t.cap
		}
		if err := t.waitChunk(ctx, chunk); err != nil {
			return err
		}
		remaining -= chunk
	}
	return nil
}

// waitChunk blocks until want tokens have been deducted from the bucket or ctx
// is cancelled. Callers must pass a want that does not exceed the capacity.
func (t *TokenBucket) waitChunk(ctx context.Context, want float64) error {
	for {
		t.mu.Lock()
		// Refill according to the time elapsed since the previous refill.
		now := time.Now()
		t.tokens += now.Sub(t.last).Seconds() * t.rate
		t.last = now
		if t.tokens > t.cap {
			t.tokens = t.cap
		}
		if t.tokens >= want {
			t.tokens -= want
			t.mu.Unlock()
			return nil
		}
		deficit := want - t.tokens
		t.mu.Unlock()

		// Sleep (outside the lock) until the deficit should be refilled. The
		// extra nanosecond rounds the truncated duration up so that we never
		// arm a zero-length timer and spin.
		sleep := time.Duration(deficit/t.rate*float64(time.Second)) + time.Nanosecond
		timer := time.NewTimer(sleep)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}
}