115 lines
3.5 KiB
Go
115 lines
3.5 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
|
HealthCheckTimeout: 15,
|
|
Models: map[string]ModelConfig{
|
|
"model1": getTestSimpleResponderConfig("model1"),
|
|
"model2": getTestSimpleResponderConfig("model2"),
|
|
"model3": getTestSimpleResponderConfig("model3"),
|
|
"model4": getTestSimpleResponderConfig("model4"),
|
|
"model5": getTestSimpleResponderConfig("model5"),
|
|
},
|
|
Groups: map[string]GroupConfig{
|
|
"G1": {
|
|
Swap: true,
|
|
Exclusive: true,
|
|
Members: []string{"model1", "model2"},
|
|
},
|
|
"G2": {
|
|
Swap: false,
|
|
Exclusive: true,
|
|
Members: []string{"model3", "model4"},
|
|
},
|
|
},
|
|
})
|
|
|
|
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
|
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
|
assert.True(t, pg.HasMember("model5"))
|
|
}
|
|
|
|
func TestProcessGroup_HasMember(t *testing.T) {
|
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
|
assert.True(t, pg.HasMember("model1"))
|
|
assert.True(t, pg.HasMember("model2"))
|
|
assert.False(t, pg.HasMember("model3"))
|
|
}
|
|
|
|
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
|
|
// and multiple requests are made in parallel, only one process is running at a time.
|
|
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
|
|
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
|
HealthCheckTimeout: 15,
|
|
Models: map[string]ModelConfig{
|
|
// use the same listening so if a model is already running, it will fail
|
|
// this is a way to test that swap isolation is working
|
|
// properly when there are parallel requests made at the
|
|
// same time.
|
|
"model1": getTestSimpleResponderConfigPort("model1", 9832),
|
|
"model2": getTestSimpleResponderConfigPort("model2", 9832),
|
|
"model3": getTestSimpleResponderConfigPort("model3", 9832),
|
|
"model4": getTestSimpleResponderConfigPort("model4", 9832),
|
|
"model5": getTestSimpleResponderConfigPort("model5", 9832),
|
|
},
|
|
Groups: map[string]GroupConfig{
|
|
"G1": {
|
|
Swap: true,
|
|
Members: []string{"model1", "model2", "model3", "model4", "model5"},
|
|
},
|
|
},
|
|
})
|
|
|
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
tests := []string{"model1", "model2", "model3", "model4", "model5"}
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
wg.Add(len(tests))
|
|
for _, modelName := range tests {
|
|
go func(modelName string) {
|
|
defer wg.Done()
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", nil)
|
|
w := httptest.NewRecorder()
|
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), modelName)
|
|
}(modelName)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
|
defer pg.StopProcesses(StopWaitForInflightRequest)
|
|
|
|
tests := []string{"model3", "model4"}
|
|
|
|
for _, modelName := range tests {
|
|
t.Run(modelName, func(t *testing.T) {
|
|
reqBody := `{"x", "y"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := httptest.NewRecorder()
|
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), modelName)
|
|
})
|
|
}
|
|
|
|
// make sure all the processes are running
|
|
for _, process := range pg.processes {
|
|
assert.Equal(t, StateReady, process.CurrentState())
|
|
}
|
|
}
|