Files
llama-swap/proxy/proxymanager_test.go
Benson Wong 04b4760e7e change profile split character to : (colon) (#21)
- change from `/` to `:` for multiple models loaded as part of a profile
- breaking change now, but allows for more compatibility with other inference engines that may have model references like `coding:Qwen/Qwen-2.5-Coder-32B`
2024-12-01 09:10:50 -08:00

144 lines
3.6 KiB
Go

package proxy
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// TestProxyManager_SwapProcessCorrectly requests "model1" and then "model2"
// through the proxy handler. After each request it checks for a 200 response
// whose body contains the model name and that the model's key is present in
// currentProcesses. Once both requests are done, currentProcesses must hold
// exactly one entry.
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
	cfg := &Config{
		HealthCheckTimeout: 15,
		Models: map[string]ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
	}

	pm := New(cfg)
	defer pm.StopProcesses()

	requestOrder := []string{"model1", "model2"}
	for _, name := range requestOrder {
		payload := bytes.NewBufferString(fmt.Sprintf(`{"model":"%s"}`, name))
		rec := httptest.NewRecorder()
		pm.HandlerFunc(rec, httptest.NewRequest("POST", "/v1/chat/completions", payload))

		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Contains(t, rec.Body.String(), name)

		// The model that was just requested must be tracked under its
		// profile-less process key.
		_, loaded := pm.currentProcesses[ProcessKeyName("", name)]
		assert.True(t, loaded, "expected %s key in currentProcesses", name)
	}

	// make sure there's only one loaded model
	assert.Len(t, pm.currentProcesses, 1)
}
// TestProxyManager_SwapMultiProcess defines a profile named "test" that lists
// two models and requests each one by the name ProcessKeyName builds from the
// profile and the model. Each request must return 200 with the expected model
// ID in the body. Afterwards currentProcesses must contain exactly two
// entries, one under each profile-qualified key.
func TestProxyManager_SwapMultiProcess(t *testing.T) {
	const (
		profileName = "test"
		firstModel  = "path1/model1"
		secondModel = "path2/model2"
	)

	firstKey := ProcessKeyName(profileName, firstModel)
	secondKey := ProcessKeyName(profileName, secondModel)

	cfg := &Config{
		HealthCheckTimeout: 15,
		Models: map[string]ModelConfig{
			firstModel:  getTestSimpleResponderConfig("model1"),
			secondModel: getTestSimpleResponderConfig("model2"),
		},
		Profiles: map[string][]string{
			profileName: {firstModel, secondModel},
		},
	}

	pm := New(cfg)
	defer pm.StopProcesses()

	// Maps the ID expected in the response body to the model name sent in the
	// request.
	requests := map[string]string{
		"model1": firstKey,
		"model2": secondKey,
	}
	for wantID, requested := range requests {
		payload := bytes.NewBufferString(fmt.Sprintf(`{"model":"%s"}`, requested))
		rec := httptest.NewRecorder()
		pm.HandlerFunc(rec, httptest.NewRequest("POST", "/v1/chat/completions", payload))

		assert.Equal(t, http.StatusOK, rec.Code)
		assert.Contains(t, rec.Body.String(), wantID)
	}

	// make sure there's two loaded models
	assert.Len(t, pm.currentProcesses, 2)
	for _, key := range []string{firstKey, secondKey} {
		_, loaded := pm.currentProcesses[key]
		assert.True(t, loaded, "expected "+key+" key in currentProcesses")
	}
}
// TestProxyManager_SwapMultiProcessParallelRequests covers overlapping
// requests for different models: when a request for a different model comes
// in, ProxyManager should wait until the first request is complete before
// swapping, and every request should still complete.
//
// One goroutine per configured model sends a request with "wait=1000ms" in the
// query string; launches are staggered by a millisecond. Each response must be
// 200 with a body equal to the requested model key. Skipped under -short.
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping slow test")
	}

	cfg := &Config{
		HealthCheckTimeout: 15,
		Models: map[string]ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
			"model3": getTestSimpleResponderConfig("model3"),
		},
	}

	pm := New(cfg)
	defer pm.StopProcesses()

	var (
		pending sync.WaitGroup
		guard   sync.Mutex // protects bodies
		bodies  = map[string]string{}
	)

	// send issues one request for model and records the response body.
	send := func(model string) {
		defer pending.Done()

		payload := bytes.NewBufferString(fmt.Sprintf(`{"model":"%s"}`, model))
		rec := httptest.NewRecorder()
		pm.HandlerFunc(rec, httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", payload))

		if rec.Code != http.StatusOK {
			t.Errorf("Expected status OK, got %d for key %s", rec.Code, model)
		}

		guard.Lock()
		bodies[model] = rec.Body.String()
		guard.Unlock()
	}

	for model := range cfg.Models {
		pending.Add(1)
		// model is evaluated at the go statement, so each goroutine gets its
		// own copy regardless of loop-variable semantics.
		go send(model)
		time.Sleep(time.Millisecond)
	}
	pending.Wait()

	assert.Len(t, bodies, len(cfg.Models))
	for model, body := range bodies {
		assert.Equal(t, model, body)
	}
}