Groups allow more control over swapping behaviour when a model is requested. The new groups feature provides three ways to control swapping: swapping within the group, swapping out other groups, or keeping the models in the group loaded persistently (never swapped out). Closes #96, #99 and #106.
97 lines
2.7 KiB
Go
97 lines
2.7 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
var processGroupTestConfig = AddDefaultGroupToConfig(Config{
|
|
HealthCheckTimeout: 15,
|
|
Models: map[string]ModelConfig{
|
|
"model1": getTestSimpleResponderConfig("model1"),
|
|
"model2": getTestSimpleResponderConfig("model2"),
|
|
"model3": getTestSimpleResponderConfig("model3"),
|
|
"model4": getTestSimpleResponderConfig("model4"),
|
|
"model5": getTestSimpleResponderConfig("model5"),
|
|
},
|
|
Groups: map[string]GroupConfig{
|
|
"G1": {
|
|
Swap: true,
|
|
Exclusive: true,
|
|
Members: []string{"model1", "model2"},
|
|
},
|
|
"G2": {
|
|
Swap: false,
|
|
Exclusive: true,
|
|
Members: []string{"model3", "model4"},
|
|
},
|
|
},
|
|
})
|
|
|
|
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
|
|
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
|
|
assert.True(t, pg.HasMember("model5"))
|
|
}
|
|
|
|
func TestProcessGroup_HasMember(t *testing.T) {
|
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
|
assert.True(t, pg.HasMember("model1"))
|
|
assert.True(t, pg.HasMember("model2"))
|
|
assert.False(t, pg.HasMember("model3"))
|
|
}
|
|
|
|
func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) {
|
|
pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger)
|
|
defer pg.StopProcesses()
|
|
|
|
tests := []string{"model1", "model2"}
|
|
|
|
for _, modelName := range tests {
|
|
t.Run(modelName, func(t *testing.T) {
|
|
reqBody := `{"x", "y"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := httptest.NewRecorder()
|
|
|
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), modelName)
|
|
|
|
// make sure only one process is in the running state
|
|
count := 0
|
|
for _, process := range pg.processes {
|
|
if process.CurrentState() == StateReady {
|
|
count++
|
|
}
|
|
}
|
|
assert.Equal(t, 1, count)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) {
|
|
pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger)
|
|
defer pg.StopProcesses()
|
|
|
|
tests := []string{"model3", "model4"}
|
|
|
|
for _, modelName := range tests {
|
|
t.Run(modelName, func(t *testing.T) {
|
|
reqBody := `{"x", "y"}`
|
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
w := httptest.NewRecorder()
|
|
assert.NoError(t, pg.ProxyRequest(modelName, w, req))
|
|
assert.Equal(t, http.StatusOK, w.Code)
|
|
assert.Contains(t, w.Body.String(), modelName)
|
|
})
|
|
}
|
|
|
|
// make sure all the processes are running
|
|
for _, process := range pg.processes {
|
|
assert.Equal(t, StateReady, process.CurrentState())
|
|
}
|
|
}
|