diff --git a/proxy/processgroup.go b/proxy/processgroup.go index 4f10d0a..cca48c1 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -60,10 +60,20 @@ func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, if pg.swap { pg.Lock() if pg.lastUsedProcess != modelID { + + // is there something already running? if pg.lastUsedProcess != "" { pg.processes[pg.lastUsedProcess].Stop() } + + // wait for the request to the new model to be fully handled + // and prevent race conditions; see issue #277 + pg.processes[modelID].ProxyRequest(writer, request) pg.lastUsedProcess = modelID + + // short circuit and exit + pg.Unlock() + return nil } pg.Unlock() } diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go index 8a1ace8..791a5a9 100644 --- a/proxy/processgroup_test.go +++ b/proxy/processgroup_test.go @@ -4,6 +4,7 @@ import ( "bytes" "net/http" "net/http/httptest" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -44,32 +45,49 @@ func TestProcessGroup_HasMember(t *testing.T) { assert.False(t, pg.HasMember("model3")) } -func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) { +// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true +// and multiple requests are made in parallel, only one process is running at a time. +func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { + var processGroupTestConfig = AddDefaultGroupToConfig(Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + // use the same listening port so if a model is already running, it will fail + // this is a way to test that swap isolation is working + // properly when there are parallel requests made at the + // same time. 
+ "model1": getTestSimpleResponderConfigPort("model1", 9832), + "model2": getTestSimpleResponderConfigPort("model2", 9832), + "model3": getTestSimpleResponderConfigPort("model3", 9832), + "model4": getTestSimpleResponderConfigPort("model4", 9832), + "model5": getTestSimpleResponderConfigPort("model5", 9832), + }, + Groups: map[string]GroupConfig{ + "G1": { + Swap: true, + Members: []string{"model1", "model2", "model3", "model4", "model5"}, + }, + }, + }) + pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) defer pg.StopProcesses(StopWaitForInflightRequest) - tests := []string{"model1", "model2"} + tests := []string{"model1", "model2", "model3", "model4", "model5"} + var wg sync.WaitGroup + + wg.Add(len(tests)) for _, modelName := range tests { - t.Run(modelName, func(t *testing.T) { - reqBody := `{"x", "y"}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + go func(modelName string) { + defer wg.Done() + req := httptest.NewRequest("POST", "/v1/chat/completions", nil) w := httptest.NewRecorder() - assert.NoError(t, pg.ProxyRequest(modelName, w, req)) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), modelName) - - // make sure only one process is in the running state - count := 0 - for _, process := range pg.processes { - if process.CurrentState() == StateReady { - count++ - } - } - assert.Equal(t, 1, count) - }) + }(modelName) } + wg.Wait() } func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index e632be5..f8ba25f 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -42,7 +42,6 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { assert.Contains(t, w.Body.String(), modelName) } } - func TestProxyManager_SwapMultiProcess(t *testing.T) { config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15,