Add a concurrency limit to Process.ProxyRequest (#123)
This commit is contained in:
@@ -23,6 +23,9 @@ type ModelConfig struct {
|
|||||||
UnloadAfter int `yaml:"ttl"`
|
UnloadAfter int `yaml:"ttl"`
|
||||||
Unlisted bool `yaml:"unlisted"`
|
Unlisted bool `yaml:"unlisted"`
|
||||||
UseModelName string `yaml:"useModelName"`
|
UseModelName string `yaml:"useModelName"`
|
||||||
|
|
||||||
|
// Limit concurrency of HTTP requests to process
|
||||||
|
ConcurrencyLimit int `yaml:"concurrencyLimit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
|
|||||||
@@ -57,10 +57,19 @@ type Process struct {
|
|||||||
// for managing shutdown state
|
// for managing shutdown state
|
||||||
shutdownCtx context.Context
|
shutdownCtx context.Context
|
||||||
shutdownCancel context.CancelFunc
|
shutdownCancel context.CancelFunc
|
||||||
|
|
||||||
|
// for managing concurrency limits
|
||||||
|
concurrencyLimitSemaphore chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
concurrentLimit := 10
|
||||||
|
if config.ConcurrencyLimit > 0 {
|
||||||
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
|
} else {
|
||||||
|
proxyLogger.Debugf("Concurrency limit for model %s not set, defaulting to 10", ID)
|
||||||
|
}
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
@@ -73,6 +82,9 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
|
|||||||
state: StateStopped,
|
state: StateStopped,
|
||||||
shutdownCtx: ctx,
|
shutdownCtx: ctx,
|
||||||
shutdownCancel: cancel,
|
shutdownCancel: cancel,
|
||||||
|
|
||||||
|
// concurrency limit
|
||||||
|
concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -417,6 +429,14 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case p.concurrencyLimitSemaphore <- struct{}{}:
|
||||||
|
defer func() { <-p.concurrencyLimitSemaphore }()
|
||||||
|
default:
|
||||||
|
http.Error(w, "Too many requests", http.StatusTooManyRequests)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
p.inFlightRequests.Add(1)
|
p.inFlightRequests.Add(1)
|
||||||
defer func() {
|
defer func() {
|
||||||
p.lastRequestHandled = time.Now()
|
p.lastRequestHandled = time.Now()
|
||||||
|
|||||||
@@ -340,3 +340,35 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
|
|||||||
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
assert.Equal(t, "upstream command exited prematurely but successfully", err.Error())
|
||||||
assert.Equal(t, process.CurrentState(), StateFailed)
|
assert.Equal(t, process.CurrentState(), StateFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_ConcurrencyLimit(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long concurrency limit test")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "concurrency_limit_test"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
|
// only allow 1 concurrent request at a time
|
||||||
|
config.ConcurrencyLimit = 1
|
||||||
|
|
||||||
|
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
|
||||||
|
assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore))
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
// launch a goroutine first to take up the semaphore
|
||||||
|
go func() {
|
||||||
|
req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req1)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// let the goroutine start
|
||||||
|
<-time.After(time.Millisecond * 25)
|
||||||
|
|
||||||
|
denied := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, denied)
|
||||||
|
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user