Files
llama-swap/proxy/proxymanager_test.go
Benson Wong cf82b3c633 Improve Concurrency and Parallel Request Handling (#19)
Rewrite the swap behaviour so that in-flight requests block process swapping until they are completed. 

Additionally: 

- add tests for parallel requests with proxy.ProxyManager and proxy.Process
- improve Process startup behaviour and simplify the code
- processes being stopped are sent SIGTERM and have 5 seconds to terminate before they are killed
2024-11-30 15:24:42 -08:00

134 lines
3.4 KiB
Go

package proxy
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses["/"+modelName]
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
}
// make sure there's only one loaded model
assert.Len(t, proxy.currentProcesses, 1)
}
// TestProxyManager_SwapMultiProcess requests model1 and model2 through the
// "test" profile, which lists both models. Each request must return 200 with
// the model ID in the body, and afterwards currentProcesses must hold exactly
// two entries, keyed "test/model1" and "test/model2".
func TestProxyManager_SwapMultiProcess(t *testing.T) {
	cfg := &Config{
		HealthCheckTimeout: 15,
		Models: map[string]ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
		Profiles: map[string][]string{
			"test": {"model1", "model2"},
		},
	}

	pm := New(cfg)
	defer pm.StopProcesses()

	// Maps the model ID expected in the response body to the profile-qualified
	// name sent in the request. Map iteration makes the request order arbitrary.
	requests := map[string]string{
		"model1": "test/model1",
		"model2": "test/model2",
	}
	for wantID, qualified := range requests {
		payload := bytes.NewBufferString(fmt.Sprintf(`{"model":"%s"}`, qualified))
		rec := httptest.NewRecorder()
		pm.HandlerFunc(rec, httptest.NewRequest("POST", "/v1/chat/completions", payload))

		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Contains(t, rec.Body.String(), wantID)
	}

	// Both profile members must be loaded at the same time.
	assert.Len(t, pm.currentProcesses, 2)

	_, hasFirst := pm.currentProcesses["test/model1"]
	assert.True(t, hasFirst, "expected test/model1 key in currentProcesses")

	_, hasSecond := pm.currentProcesses["test/model2"]
	assert.True(t, hasSecond, "expected test/model2 key in currentProcesses")
}
// TestProxyManager_SwapMultiProcessParallelRequests fires one request per
// configured model from separate goroutines, started about a millisecond
// apart, each with "wait=1000ms" in the query string. When a request for a
// different model arrives, ProxyManager is expected to let the in-flight
// request finish before swapping, so every request must return 200 with a
// body equal to the model it asked for.
//
// The test is skipped in -short mode.
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping slow test")
	}

	cfg := &Config{
		HealthCheckTimeout: 15,
		Models: map[string]ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
			"model3": getTestSimpleResponderConfig("model3"),
		},
	}

	pm := New(cfg)
	defer pm.StopProcesses()

	var (
		wg sync.WaitGroup
		mu sync.Mutex // guards bodies
	)
	bodies := make(map[string]string)

	for model := range cfg.Models {
		wg.Add(1)
		go func(model string) {
			defer wg.Done()

			payload := bytes.NewBufferString(fmt.Sprintf(`{"model":"%s"}`, model))
			rec := httptest.NewRecorder()
			pm.HandlerFunc(rec, httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", payload))

			if rec.Code != http.StatusOK {
				t.Errorf("Expected status OK, got %d for key %s", rec.Code, model)
			}

			mu.Lock()
			bodies[model] = rec.Body.String()
			mu.Unlock()
		}(model)

		// Stagger the goroutine launches slightly.
		time.Sleep(time.Millisecond)
	}

	wg.Wait()

	// Every request must have produced a result, and each body must equal the
	// model that was requested.
	assert.Len(t, bodies, len(cfg.Models))
	for model, body := range bodies {
		assert.Equal(t, model, body)
	}
}