63 lines
1.4 KiB
Go
63 lines
1.4 KiB
Go
package datasets
|
||
|
||
import (
|
||
"context"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// TokenBucket throttles a byte stream to a fixed number of bytes per second.
//
// A bucket created by NewTokenBucket starts full, holding one second's worth
// of tokens (its capacity equals its rate). Wait makes callers block until
// tokens are available for their request. A bucket whose rate is zero (for
// example the zero value) does not throttle at all.
type TokenBucket struct {
	mu sync.Mutex // guards tokens and last, which Wait updates

	rate   float64   // refill rate, in tokens per second
	tokens float64   // current token balance
	cap    float64   // upper bound on the token balance
	last   time.Time // when tokens was last refilled
}
|
||
|
||
// NewTokenBucket returns a TokenBucket emitting at most bytesPerSecond.
|
||
// A non-positive rate disables throttling (Wait becomes a no-op).
|
||
func NewTokenBucket(bytesPerSecond int64) *TokenBucket {
|
||
if bytesPerSecond <= 0 {
|
||
return &TokenBucket{rate: 0}
|
||
}
|
||
r := float64(bytesPerSecond)
|
||
return &TokenBucket{rate: r, tokens: r, cap: r, last: time.Now()}
|
||
}
|
||
|
||
// Wait blocks until n tokens are available or ctx is cancelled.
|
||
func (t *TokenBucket) Wait(ctx context.Context, n int) error {
|
||
if t.rate <= 0 {
|
||
return nil
|
||
}
|
||
want := float64(n)
|
||
|
||
for {
|
||
t.mu.Lock()
|
||
now := time.Now()
|
||
elapsed := now.Sub(t.last).Seconds()
|
||
t.last = now
|
||
t.tokens += elapsed * t.rate
|
||
if t.tokens > t.cap {
|
||
t.tokens = t.cap
|
||
}
|
||
if t.tokens >= want {
|
||
t.tokens -= want
|
||
t.mu.Unlock()
|
||
return nil
|
||
}
|
||
// Sleep until we expect enough tokens.
|
||
need := want - t.tokens
|
||
sleep := time.Duration(need / t.rate * float64(time.Second))
|
||
t.mu.Unlock()
|
||
|
||
select {
|
||
case <-ctx.Done():
|
||
return ctx.Err()
|
||
case <-time.After(sleep):
|
||
}
|
||
}
|
||
}
|