proxy/config: create config package and migrate configuration (#329)

* proxy/config: create config package and migrate configuration

The configuration is become more complex as llama-swap adds more
advanced features. This commit moves config to its own package so it can
be developed independently of the proxy package.

Additionally, enforcing a public API for a configuration will allow
downstream usage to be more decoupled.
This commit is contained in:
Benson Wong
2025-09-28 16:50:06 -07:00
committed by GitHub
parent 9e3d491c85
commit 216c40b951
14 changed files with 108 additions and 97 deletions

View File

@@ -16,14 +16,15 @@ import (
"time"
"github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert"
"github.com/tidwall/gjson"
)
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
@@ -44,14 +45,14 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
}
}
func TestProxyManager_SwapMultiProcess(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
"G1": {
Swap: true,
Exclusive: false,
@@ -89,14 +90,14 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
// Test that a persistent group is not affected by the swapping behaviour of
// other groups.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
// the forever group is persistent and should not be affected by model1
"forever": {
Swap: true,
@@ -133,9 +134,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test")
}
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
@@ -196,9 +197,9 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
model2Config.Name = " " // empty whitespace only strings will get ignored
model2Config.Description = " "
config := Config{
config := config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": getTestSimpleResponderConfig("model3"),
@@ -283,13 +284,13 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
// Intentionally add models in non-sorted order and with an unlisted model
config := Config{
config := config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"zeta": getTestSimpleResponderConfig("zeta"),
"alpha": getTestSimpleResponderConfig("alpha"),
"beta": getTestSimpleResponderConfig("beta"),
"hidden": func() ModelConfig {
"hidden": func() config.ModelConfig {
mc := getTestSimpleResponderConfig("hidden")
mc.Unlisted = true
return mc
@@ -337,15 +338,15 @@ func TestProxyManager_Shutdown(t *testing.T) {
model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/"
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": model1Config,
"model2": model2Config,
"model3": model3Config,
},
LogLevel: "error",
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
"test": {
Swap: false,
Members: []string{"model1", "model2", "model3"},
@@ -380,21 +381,21 @@ func TestProxyManager_Shutdown(t *testing.T) {
}
func TestProxyManager_Unload(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
conf := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
proxy := New(conf)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder()
proxy.ServeHTTP(w, req)
@@ -402,23 +403,23 @@ func TestProxyManager_Unload(t *testing.T) {
assert.Equal(t, w.Body.String(), "OK")
select {
case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
// good
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop")
}
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
}
func TestProxyManager_UnloadSingleModel(t *testing.T) {
const testGroupId = "testGroup"
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]GroupConfig{
Groups: map[string]config.GroupConfig{
testGroupId: {
Swap: false,
Members: []string{"model1", "model2"},
@@ -463,9 +464,9 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
// Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
@@ -528,9 +529,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
}
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
LogLevel: "error",
@@ -581,15 +582,15 @@ func TestProxyManager_UseModelName(t *testing.T) {
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := AddDefaultGroupToConfig(Config{
conf := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": modelConfig,
},
LogLevel: "error",
})
proxy := New(config)
proxy := New(conf)
defer proxy.StopProcesses(StopWaitForInflightRequest)
requestedModel := "model1"
@@ -644,9 +645,9 @@ func TestProxyManager_UseModelName(t *testing.T) {
}
func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
@@ -720,7 +721,7 @@ models:
aliases: [model-alias]
`, getSimpleResponderPath())
config, err := LoadConfigFromReader(strings.NewReader(configStr))
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(config)
@@ -743,9 +744,9 @@ models:
}
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
@@ -768,14 +769,14 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
func TestProxyManager_FiltersStripParams(t *testing.T) {
modelConfig := getTestSimpleResponderConfig("model1")
modelConfig.Filters = ModelFilters{
modelConfig.Filters = config.ModelFilters{
StripParams: "temperature, model, stream",
}
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
LogLevel: "error",
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": modelConfig,
},
})
@@ -801,9 +802,9 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
}
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
@@ -836,9 +837,9 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
}
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
@@ -871,9 +872,9 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
}
func TestProxyManager_HealthEndpoint(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
@@ -890,9 +891,9 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
// Ensure the custom llama-server /completion endpoint proxies correctly
func TestProxyManager_CompletionEndpoint(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
@@ -936,7 +937,7 @@ models:
`, "${simpleresponderpath}", simpleResponderPath, -1)
// Create a test model configuration
config, err := LoadConfigFromReader(strings.NewReader(configStr))
config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
if !assert.NoError(t, err, "Invalid configuration") {
return
}
@@ -970,9 +971,9 @@ models:
}
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
@@ -1009,9 +1010,9 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
}
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
Models: map[string]config.ModelConfig{
"streaming-model": getTestSimpleResponderConfig("streaming-model"),
},
LogLevel: "error",