package proxy import ( "bytes" "fmt" "net/http" "net/http/httptest" "sync" "testing" "time" "github.com/stretchr/testify/assert" ) func TestProxyManager_SwapProcessCorrectly(t *testing.T) { config := &Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), }, } proxy := New(config) defer proxy.StopProcesses() for _, modelName := range []string{"model1", "model2"} { reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() proxy.HandlerFunc(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), modelName) _, exists := proxy.currentProcesses[ProcessKeyName("", modelName)] assert.True(t, exists, "expected %s key in currentProcesses", modelName) } // make sure there's only one loaded model assert.Len(t, proxy.currentProcesses, 1) } func TestProxyManager_SwapMultiProcess(t *testing.T) { model1 := "path1/model1" model2 := "path2/model2" profileModel1 := ProcessKeyName("test", model1) profileModel2 := ProcessKeyName("test", model2) config := &Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ model1: getTestSimpleResponderConfig("model1"), model2: getTestSimpleResponderConfig("model2"), }, Profiles: map[string][]string{ "test": {model1, model2}, }, } proxy := New(config) defer proxy.StopProcesses() for modelID, requestedModel := range map[string]string{ "model1": profileModel1, "model2": profileModel2, } { reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() proxy.HandlerFunc(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), modelID) } // make sure there's two loaded models assert.Len(t, proxy.currentProcesses, 2) _, exists := proxy.currentProcesses[profileModel1] assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses") _, exists = proxy.currentProcesses[profileModel2] assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses") } // When a request for a different model comes in ProxyManager should wait until // the first request is complete before swapping. Both requests should complete func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { if testing.Short() { t.Skip("skipping slow test") } config := &Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), "model3": getTestSimpleResponderConfig("model3"), }, } proxy := New(config) defer proxy.StopProcesses() results := map[string]string{} var wg sync.WaitGroup var mu sync.Mutex for key := range config.Models { wg.Add(1) go func(key string) { defer wg.Done() reqBody := fmt.Sprintf(`{"model":"%s"}`, key) req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() proxy.HandlerFunc(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status OK, got %d for key %s", w.Code, key) } mu.Lock() results[key] = w.Body.String() mu.Unlock() }(key) <-time.After(time.Millisecond) } wg.Wait() assert.Len(t, results, len(config.Models)) for key, result := range results { assert.Equal(t, key, result) } }