proxy/config: create config package and migrate configuration (#329)

* proxy/config: create config package and migrate configuration

The configuration is become more complex as llama-swap adds more
advanced features. This commit moves config to its own package so it can
be developed independently of the proxy package.

Additionally, enforcing a public API for a configuration will allow
downstream usage to be more decoupled.
This commit is contained in:
Benson Wong
2025-09-28 16:50:06 -07:00
committed by GitHub
parent 9e3d491c85
commit 216c40b951
14 changed files with 108 additions and 97 deletions

View File

@@ -24,10 +24,11 @@ proxy/ui_dist/placeholder.txt:
touch $@ touch $@
test: proxy/ui_dist/placeholder.txt test: proxy/ui_dist/placeholder.txt
go test -short -v -count=1 ./proxy go test -short -count=1 ./proxy/...
# for CI - full test (takes longer)
test-all: proxy/ui_dist/placeholder.txt test-all: proxy/ui_dist/placeholder.txt
go test -v -count=1 ./proxy go test -count=1 ./proxy/...
ui/node_modules: ui/node_modules:
cd ui && npm install cd ui && npm install
@@ -81,4 +82,4 @@ release:
git tag "$$new_tag"; git tag "$$new_tag";
# Phony targets # Phony targets
.PHONY: all clean ui mac linux windows simple-responder .PHONY: all clean ui mac linux windows simple-responder test test-all

View File

@@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
"github.com/mostlygeek/llama-swap/proxy/config"
) )
var ( var (
@@ -38,13 +39,13 @@ func main() {
os.Exit(0) os.Exit(0)
} }
config, err := proxy.LoadConfig(*configPath) conf, err := config.LoadConfig(*configPath)
if err != nil { if err != nil {
fmt.Printf("Error loading config: %v\n", err) fmt.Printf("Error loading config: %v\n", err)
os.Exit(1) os.Exit(1)
} }
if len(config.Profiles) > 0 { if len(conf.Profiles) > 0 {
fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.") fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.")
} }
@@ -67,7 +68,7 @@ func main() {
// Support for watching config and reloading when it changes // Support for watching config and reloading when it changes
reloadProxyManager := func() { reloadProxyManager := func() {
if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok { if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok {
config, err = proxy.LoadConfig(*configPath) conf, err = config.LoadConfig(*configPath)
if err != nil { if err != nil {
fmt.Printf("Warning, unable to reload configuration: %v\n", err) fmt.Printf("Warning, unable to reload configuration: %v\n", err)
return return
@@ -75,7 +76,7 @@ func main() {
fmt.Println("Configuration Changed") fmt.Println("Configuration Changed")
currentPM.Shutdown() currentPM.Shutdown()
srv.Handler = proxy.New(config) srv.Handler = proxy.New(conf)
fmt.Println("Configuration Reloaded") fmt.Println("Configuration Reloaded")
// wait a few seconds and tell any UI to reload // wait a few seconds and tell any UI to reload
@@ -85,12 +86,12 @@ func main() {
}) })
}) })
} else { } else {
config, err = proxy.LoadConfig(*configPath) conf, err = config.LoadConfig(*configPath)
if err != nil { if err != nil {
fmt.Printf("Error, unable to load configuration: %v\n", err) fmt.Printf("Error, unable to load configuration: %v\n", err)
os.Exit(1) os.Exit(1)
} }
srv.Handler = proxy.New(config) srv.Handler = proxy.New(conf)
} }
} }

View File

@@ -1,4 +1,4 @@
package proxy package config
import ( import (
"fmt" "fmt"

View File

@@ -1,6 +1,6 @@
//go:build !windows //go:build !windows
package proxy package config
import ( import (
"os" "os"

View File

@@ -1,4 +1,4 @@
package proxy package config
import ( import (
"slices" "slices"

View File

@@ -1,6 +1,6 @@
//go:build windows //go:build windows
package proxy package config
import ( import (
"os" "os"

View File

@@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy/config"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -65,18 +66,18 @@ func getTestPort() int {
return port return port
} }
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { func getTestSimpleResponderConfig(expectedMessage string) config.ModelConfig {
return getTestSimpleResponderConfigPort(expectedMessage, getTestPort()) return getTestSimpleResponderConfigPort(expectedMessage, getTestPort())
} }
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { func getTestSimpleResponderConfigPort(expectedMessage string, port int) config.ModelConfig {
// Create a YAML string with just the values we want to set // Create a YAML string with just the values we want to set
yamlStr := fmt.Sprintf(` yamlStr := fmt.Sprintf(`
cmd: '%s --port %d --silent --respond %s' cmd: '%s --port %d --silent --respond %s'
proxy: "http://127.0.0.1:%d" proxy: "http://127.0.0.1:%d"
`, simpleResponderPath, port, expectedMessage, port) `, simpleResponderPath, port, expectedMessage, port)
var cfg ModelConfig var cfg config.ModelConfig
if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil { if err := yaml.Unmarshal([]byte(yamlStr), &cfg); err != nil {
panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr)) panic(fmt.Sprintf("failed to unmarshal test config: %v in [%s]", err, yamlStr))
} }

View File

@@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
) )
// TokenMetrics represents parsed token statistics from llama-server logs // TokenMetrics represents parsed token statistics from llama-server logs
@@ -38,7 +39,7 @@ type MetricsMonitor struct {
nextID int nextID int
} }
func NewMetricsMonitor(config *Config) *MetricsMonitor { func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
maxMetrics := config.MetricsMaxInMemory maxMetrics := config.MetricsMaxInMemory
if maxMetrics <= 0 { if maxMetrics <= 0 {
maxMetrics = 1000 // Default fallback maxMetrics = 1000 // Default fallback

View File

@@ -16,6 +16,7 @@ import (
"time" "time"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
) )
type ProcessState string type ProcessState string
@@ -39,7 +40,7 @@ const (
type Process struct { type Process struct {
ID string ID string
config ModelConfig config config.ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
// PR #155 called to cancel the upstream process // PR #155 called to cancel the upstream process
@@ -74,7 +75,7 @@ type Process struct {
failedStartCount int failedStartCount int
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, processLogger *LogMonitor, proxyLogger *LogMonitor) *Process {
concurrentLimit := 10 concurrentLimit := 10
if config.ConcurrencyLimit > 0 { if config.ConcurrencyLimit > 0 {
concurrentLimit = config.ConcurrencyLimit concurrentLimit = config.ConcurrencyLimit
@@ -539,7 +540,7 @@ func (p *Process) cmdStopUpstreamProcess() error {
if p.config.CmdStop != "" { if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process // replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid))) stopArgs, err := config.SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil { if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err) p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return err return err

View File

@@ -10,6 +10,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -90,7 +91,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
// test that the automatic start returns the expected error type // test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) { func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration // Create a process configuration
config := ModelConfig{ config := config.ModelConfig{
Cmd: "nonexistent-command", Cmd: "nonexistent-command",
Proxy: "http://127.0.0.1:9913", Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health", CheckEndpoint: "/health",
@@ -325,7 +326,7 @@ func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
// should run and exit but interrupt the long checkHealthTimeout // should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5 checkHealthTimeout := 5
config := ModelConfig{ config := config.ModelConfig{
Cmd: "sleep 1", Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913", Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health", CheckEndpoint: "/health",
@@ -402,7 +403,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
binaryPath := getSimpleResponderPath() binaryPath := getSimpleResponderPath()
port := getTestPort() port := getTestPort()
config := ModelConfig{ conf := config.ModelConfig{
// note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent // note --ignore-sig-term which ignores the SIGTERM signal so a SIGKILL must be sent
// to force the process to exit // to force the process to exit
Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage), Cmd: fmt.Sprintf("%s --port %d --respond %s --silent --ignore-sig-term", binaryPath, port, expectedMessage),
@@ -410,7 +411,7 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
process := NewProcess("stop_immediate", 2, config, debugLogger, debugLogger) process := NewProcess("stop_immediate", 2, conf, debugLogger, debugLogger)
defer process.Stop() defer process.Stop()
// reduce to make testing go faster // reduce to make testing go faster
@@ -450,15 +451,15 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
} }
func TestProcess_StopCmd(t *testing.T) { func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd") conf := getTestSimpleResponderConfig("test_stop_cmd")
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}" conf.CmdStop = "taskkill /f /t /pid ${PID}"
} else { } else {
config.CmdStop = "kill -TERM ${PID}" conf.CmdStop = "kill -TERM ${PID}"
} }
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger) process := NewProcess("testStopCmd", 2, conf, debugLogger, debugLogger)
defer process.Stop() defer process.Stop()
err := process.start() err := process.start()
@@ -470,15 +471,15 @@ func TestProcess_StopCmd(t *testing.T) {
func TestProcess_EnvironmentSetCorrectly(t *testing.T) { func TestProcess_EnvironmentSetCorrectly(t *testing.T) {
expectedMessage := "test_env_not_emptied" expectedMessage := "test_env_not_emptied"
config := getTestSimpleResponderConfig(expectedMessage) conf := getTestSimpleResponderConfig(expectedMessage)
// ensure that the the default config does not blank out the inherited environment // ensure that the the default config does not blank out the inherited environment
configWEnv := config configWEnv := conf
// ensure the additiona variables are appended to the process' environment // ensure the additiona variables are appended to the process' environment
configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2") configWEnv.Env = append(configWEnv.Env, "TEST_ENV1=1", "TEST_ENV2=2")
process1 := NewProcess("env_test", 2, config, debugLogger, debugLogger) process1 := NewProcess("env_test", 2, conf, debugLogger, debugLogger)
process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger) process2 := NewProcess("env_test", 2, configWEnv, debugLogger, debugLogger)
process1.start() process1.start()

View File

@@ -5,12 +5,14 @@ import (
"net/http" "net/http"
"slices" "slices"
"sync" "sync"
"github.com/mostlygeek/llama-swap/proxy/config"
) )
type ProcessGroup struct { type ProcessGroup struct {
sync.Mutex sync.Mutex
config Config config config.Config
id string id string
swap bool swap bool
exclusive bool exclusive bool
@@ -24,7 +26,7 @@ type ProcessGroup struct {
lastUsedProcess string lastUsedProcess string
} }
func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { func NewProcessGroup(id string, config config.Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup {
groupConfig, ok := config.Groups[id] groupConfig, ok := config.Groups[id]
if !ok { if !ok {
panic("Unable to find configuration for group id: " + id) panic("Unable to find configuration for group id: " + id)

View File

@@ -7,19 +7,20 @@ import (
"sync" "sync"
"testing" "testing"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
var processGroupTestConfig = AddDefaultGroupToConfig(Config{ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
"model4": getTestSimpleResponderConfig("model4"), "model4": getTestSimpleResponderConfig("model4"),
"model5": getTestSimpleResponderConfig("model5"), "model5": getTestSimpleResponderConfig("model5"),
}, },
Groups: map[string]GroupConfig{ Groups: map[string]config.GroupConfig{
"G1": { "G1": {
Swap: true, Swap: true,
Exclusive: true, Exclusive: true,
@@ -34,7 +35,7 @@ var processGroupTestConfig = AddDefaultGroupToConfig(Config{
}) })
func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) { func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) {
pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) pg := NewProcessGroup(config.DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger)
assert.True(t, pg.HasMember("model5")) assert.True(t, pg.HasMember("model5"))
} }
@@ -48,9 +49,9 @@ func TestProcessGroup_HasMember(t *testing.T) {
// TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true // TestProcessGroup_ProxyRequestSwapIsTrueParallel tests that when swap is true
// and multiple requests are made in parallel, only one process is running at a time. // and multiple requests are made in parallel, only one process is running at a time.
func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) { func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
var processGroupTestConfig = AddDefaultGroupToConfig(Config{ var processGroupTestConfig = config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
// use the same listening so if a model is already running, it will fail // use the same listening so if a model is already running, it will fail
// this is a way to test that swap isolation is working // this is a way to test that swap isolation is working
// properly when there are parallel requests made at the // properly when there are parallel requests made at the
@@ -61,7 +62,7 @@ func TestProcessGroup_ProxyRequestSwapIsTrueParallel(t *testing.T) {
"model4": getTestSimpleResponderConfigPort("model4", 9832), "model4": getTestSimpleResponderConfigPort("model4", 9832),
"model5": getTestSimpleResponderConfigPort("model5", 9832), "model5": getTestSimpleResponderConfigPort("model5", 9832),
}, },
Groups: map[string]GroupConfig{ Groups: map[string]config.GroupConfig{
"G1": { "G1": {
Swap: true, Swap: true,
Members: []string{"model1", "model2", "model3", "model4", "model5"}, Members: []string{"model1", "model2", "model3", "model4", "model5"},

View File

@@ -16,6 +16,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -27,7 +28,7 @@ const (
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
config Config config config.Config
ginEngine *gin.Engine ginEngine *gin.Engine
// logging // logging
@@ -44,7 +45,7 @@ type ProxyManager struct {
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
} }
func New(config Config) *ProxyManager { func New(config config.Config) *ProxyManager {
// set up loggers // set up loggers
stdoutLogger := NewLogMonitorWriter(os.Stdout) stdoutLogger := NewLogMonitorWriter(os.Stdout)
upstreamLogger := NewLogMonitorWriter(stdoutLogger) upstreamLogger := NewLogMonitorWriter(stdoutLogger)

View File

@@ -16,14 +16,15 @@ import (
"time" "time"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
"github.com/mostlygeek/llama-swap/proxy/config"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
func TestProxyManager_SwapProcessCorrectly(t *testing.T) { func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
@@ -44,14 +45,14 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
} }
} }
func TestProxyManager_SwapMultiProcess(t *testing.T) { func TestProxyManager_SwapMultiProcess(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", LogLevel: "error",
Groups: map[string]GroupConfig{ Groups: map[string]config.GroupConfig{
"G1": { "G1": {
Swap: true, Swap: true,
Exclusive: false, Exclusive: false,
@@ -89,14 +90,14 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
// Test that a persistent group is not affected by the swapping behaviour of // Test that a persistent group is not affected by the swapping behaviour of
// other groups. // other groups.
func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), // goes into the default group "model1": getTestSimpleResponderConfig("model1"), // goes into the default group
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
LogLevel: "error", LogLevel: "error",
Groups: map[string]GroupConfig{ Groups: map[string]config.GroupConfig{
// the forever group is persistent and should not be affected by model1 // the forever group is persistent and should not be affected by model1
"forever": { "forever": {
Swap: true, Swap: true,
@@ -133,9 +134,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
t.Skip("skipping slow test") t.Skip("skipping slow test")
} }
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
@@ -196,9 +197,9 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
model2Config.Name = " " // empty whitespace only strings will get ignored model2Config.Name = " " // empty whitespace only strings will get ignored
model2Config.Description = " " model2Config.Description = " "
config := Config{ config := config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": model1Config, "model1": model1Config,
"model2": model2Config, "model2": model2Config,
"model3": getTestSimpleResponderConfig("model3"), "model3": getTestSimpleResponderConfig("model3"),
@@ -283,13 +284,13 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
// Intentionally add models in non-sorted order and with an unlisted model // Intentionally add models in non-sorted order and with an unlisted model
config := Config{ config := config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"zeta": getTestSimpleResponderConfig("zeta"), "zeta": getTestSimpleResponderConfig("zeta"),
"alpha": getTestSimpleResponderConfig("alpha"), "alpha": getTestSimpleResponderConfig("alpha"),
"beta": getTestSimpleResponderConfig("beta"), "beta": getTestSimpleResponderConfig("beta"),
"hidden": func() ModelConfig { "hidden": func() config.ModelConfig {
mc := getTestSimpleResponderConfig("hidden") mc := getTestSimpleResponderConfig("hidden")
mc.Unlisted = true mc.Unlisted = true
return mc return mc
@@ -337,15 +338,15 @@ func TestProxyManager_Shutdown(t *testing.T) {
model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config := getTestSimpleResponderConfigPort("model3", 9993)
model3Config.Proxy = "http://localhost:10003/" model3Config.Proxy = "http://localhost:10003/"
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": model1Config, "model1": model1Config,
"model2": model2Config, "model2": model2Config,
"model3": model3Config, "model3": model3Config,
}, },
LogLevel: "error", LogLevel: "error",
Groups: map[string]GroupConfig{ Groups: map[string]config.GroupConfig{
"test": { "test": {
Swap: false, Swap: false,
Members: []string{"model1", "model2", "model3"}, Members: []string{"model1", "model2", "model3"},
@@ -380,21 +381,21 @@ func TestProxyManager_Shutdown(t *testing.T) {
} }
func TestProxyManager_Unload(t *testing.T) { func TestProxyManager_Unload(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ conf := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
}) })
proxy := New(config) proxy := New(conf)
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil) req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
@@ -402,23 +403,23 @@ func TestProxyManager_Unload(t *testing.T) {
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
select { select {
case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan: case <-proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
// good // good
case <-time.After(2 * time.Second): case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop") t.Fatal("timeout waiting for model1 to stop")
} }
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
} }
func TestProxyManager_UnloadSingleModel(t *testing.T) { func TestProxyManager_UnloadSingleModel(t *testing.T) {
const testGroupId = "testGroup" const testGroupId = "testGroup"
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
Groups: map[string]GroupConfig{ Groups: map[string]config.GroupConfig{
testGroupId: { testGroupId: {
Swap: false, Swap: false,
Members: []string{"model1", "model2"}, Members: []string{"model1", "model2"},
@@ -463,9 +464,9 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
// Test issue #61 `Listing the current list of models and the loaded model.` // Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) { func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration // Shared configuration
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), "model2": getTestSimpleResponderConfig("model2"),
}, },
@@ -528,9 +529,9 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
} }
func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -581,15 +582,15 @@ func TestProxyManager_UseModelName(t *testing.T) {
modelConfig := getTestSimpleResponderConfig(upstreamModelName) modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName modelConfig.UseModelName = upstreamModelName
config := AddDefaultGroupToConfig(Config{ conf := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": modelConfig, "model1": modelConfig,
}, },
LogLevel: "error", LogLevel: "error",
}) })
proxy := New(config) proxy := New(conf)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses(StopWaitForInflightRequest)
requestedModel := "model1" requestedModel := "model1"
@@ -644,9 +645,9 @@ func TestProxyManager_UseModelName(t *testing.T) {
} }
func TestProxyManager_CORSOptionsHandler(t *testing.T) { func TestProxyManager_CORSOptionsHandler(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -720,7 +721,7 @@ models:
aliases: [model-alias] aliases: [model-alias]
`, getSimpleResponderPath()) `, getSimpleResponderPath())
config, err := LoadConfigFromReader(strings.NewReader(configStr)) config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err) assert.NoError(t, err)
proxy := New(config) proxy := New(config)
@@ -743,9 +744,9 @@ models:
} }
func TestProxyManager_ChatContentLength(t *testing.T) { func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -768,14 +769,14 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
func TestProxyManager_FiltersStripParams(t *testing.T) { func TestProxyManager_FiltersStripParams(t *testing.T) {
modelConfig := getTestSimpleResponderConfig("model1") modelConfig := getTestSimpleResponderConfig("model1")
modelConfig.Filters = ModelFilters{ modelConfig.Filters = config.ModelFilters{
StripParams: "temperature, model, stream", StripParams: "temperature, model, stream",
} }
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
LogLevel: "error", LogLevel: "error",
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": modelConfig, "model1": modelConfig,
}, },
}) })
@@ -801,9 +802,9 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
} }
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) { func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -836,9 +837,9 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
} }
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) { func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -871,9 +872,9 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
} }
func TestProxyManager_HealthEndpoint(t *testing.T) { func TestProxyManager_HealthEndpoint(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -890,9 +891,9 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
// Ensure the custom llama-server /completion endpoint proxies correctly // Ensure the custom llama-server /completion endpoint proxies correctly
func TestProxyManager_CompletionEndpoint(t *testing.T) { func TestProxyManager_CompletionEndpoint(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -936,7 +937,7 @@ models:
`, "${simpleresponderpath}", simpleResponderPath, -1) `, "${simpleresponderpath}", simpleResponderPath, -1)
// Create a test model configuration // Create a test model configuration
config, err := LoadConfigFromReader(strings.NewReader(configStr)) config, err := config.LoadConfigFromReader(strings.NewReader(configStr))
if !assert.NoError(t, err, "Invalid configuration") { if !assert.NoError(t, err, "Invalid configuration") {
return return
} }
@@ -970,9 +971,9 @@ models:
} }
func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), "model1": getTestSimpleResponderConfig("model1"),
}, },
LogLevel: "error", LogLevel: "error",
@@ -1009,9 +1010,9 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
} }
func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) { func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]config.ModelConfig{
"streaming-model": getTestSimpleResponderConfig("streaming-model"), "streaming-model": getTestSimpleResponderConfig("streaming-model"),
}, },
LogLevel: "error", LogLevel: "error",