diff --git a/proxy/config.go b/proxy/config.go index e82ffda..cb5284b 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -23,6 +23,9 @@ type ModelConfig struct { UnloadAfter int `yaml:"ttl"` Unlisted bool `yaml:"unlisted"` UseModelName string `yaml:"useModelName"` + + // Limit concurrency of HTTP requests to process + ConcurrencyLimit int `yaml:"concurrencyLimit"` } func (m *ModelConfig) SanitizedCommand() ([]string, error) { diff --git a/proxy/process.go b/proxy/process.go index 2181aed..4f07fc7 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -57,10 +57,19 @@ type Process struct { // for managing shutdown state shutdownCtx context.Context shutdownCancel context.CancelFunc + + // for managing concurrency limits + concurrencyLimitSemaphore chan struct{} } func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { ctx, cancel := context.WithCancel(context.Background()) + concurrentLimit := 10 + if config.ConcurrencyLimit > 0 { + concurrentLimit = config.ConcurrencyLimit + } else { + proxyLogger.Debugf("Concurrency limit for model %s not set, defaulting to 10", ID) + } return &Process{ ID: ID, config: config, @@ -73,6 +82,9 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo state: StateStopped, shutdownCtx: ctx, shutdownCancel: cancel, + + // concurrency limit + concurrencyLimitSemaphore: make(chan struct{}, concurrentLimit), } } @@ -417,6 +429,14 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { return } + select { + case p.concurrencyLimitSemaphore <- struct{}{}: + defer func() { <-p.concurrencyLimitSemaphore }() + default: + http.Error(w, "Too many requests", http.StatusTooManyRequests) + return + } + p.inFlightRequests.Add(1) defer func() { p.lastRequestHandled = time.Now() diff --git a/proxy/process_test.go b/proxy/process_test.go index edd3357..f45a404 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -340,3 +340,35 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) { assert.Equal(t, "upstream command exited prematurely but successfully", err.Error()) assert.Equal(t, process.CurrentState(), StateFailed) } + +func TestProcess_ConcurrencyLimit(t *testing.T) { + if testing.Short() { + t.Skip("skipping long concurrency limit test") + } + + expectedMessage := "concurrency_limit_test" + config := getTestSimpleResponderConfig(expectedMessage) + + // only allow 1 concurrent request at a time + config.ConcurrencyLimit = 1 + + process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) + assert.Equal(t, 1, cap(process.concurrencyLimitSemaphore)) + defer process.Stop() + + // launch a goroutine first to take up the semaphore + go func() { + req1 := httptest.NewRequest("GET", "/slow-respond?echo=12345&delay=75ms", nil) + w := httptest.NewRecorder() + process.ProxyRequest(w, req1) + assert.Equal(t, http.StatusOK, w.Code) + }() + + // let the goroutine start + <-time.After(time.Millisecond * 25) + + denied := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + process.ProxyRequest(w, denied) + assert.Equal(t, http.StatusTooManyRequests, w.Code) +}