From 448ccae959d71f0a2e9635cae0ecd48890a829d5 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Fri, 2 May 2025 22:35:38 -0700 Subject: [PATCH] Introduce Groups Feature (#107) Groups allows more control over swapping behaviour when a model is requested. The new groups feature provides three ways to control swapping: within the group, swapping out other groups or keep the models in the group loaded persistently (never swapped out). Closes #96, #99 and #106. --- README.md | 62 ++++- llama-swap.go | 4 + proxy/config.go | 104 +++++++- proxy/config_test.go | 97 +++++++- proxy/helpers_test.go | 12 + proxy/process.go | 4 + proxy/processgroup.go | 111 +++++++++ proxy/processgroup_test.go | 96 +++++++ proxy/proxymanager.go | 250 +++++++------------ proxy/proxymanager_test.go | 496 ++++++++++++++----------------------- 10 files changed, 754 insertions(+), 482 deletions(-) create mode 100644 proxy/processgroup.go create mode 100644 proxy/processgroup_test.go diff --git a/README.md b/README.md index e6b571d..28acd8a 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Written in golang, it is very easy to install (single binary with no dependancie - `/upstream/:model_id` - direct access to upstream HTTP server ([demo](https://github.com/mostlygeek/llama-swap/pull/31)) - `/unload` - manually unload running models ([#58](https://github.com/mostlygeek/llama-swap/issues/58)) - `/running` - list currently running models ([#61](https://github.com/mostlygeek/llama-swap/issues/61)) -- ✅ Run multiple models at once with `profiles` ([docs](https://github.com/mostlygeek/llama-swap/issues/53#issuecomment-2660761741)) +- ✅ Run multiple models at once with `Groups` ([#107](https://github.com/mostlygeek/llama-swap/issues/107)) - ✅ Automatic unloading of models after timeout by setting a `ttl` - ✅ Use any local OpenAI compatible server (llama.cpp, vllm, tabbyAPI, etc) - ✅ Docker and Podman support @@ -36,7 +36,7 @@ Written in golang, it is very easy to install (single binary with no dependancie When a request is made to an OpenAI compatible endpoint, lama-swap will extract the `model` value and load the appropriate server configuration to serve it. If the wrong upstream server is running, it will be replaced with the correct one. This is where the "swap" part comes in. The upstream server is automatically swapped to the correct one to serve the request. -In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `profiles` feature can load multiple models at the same time. You have complete control over how your system resources are used. +In the most basic configuration llama-swap handles one model at a time. For more advanced use cases, the `groups` feature allows multiple models to be loaded at the same time. You have complete control over how your system resources are used. ## config.yaml @@ -120,16 +120,58 @@ models: ghcr.io/ggerganov/llama.cpp:server --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' -# profiles eliminates swapping by running multiple models at the same time +# Groups provide advanced controls over model swapping behaviour. Using groups +# some models can be kept loaded indefinitely, while others are swapped out. # # Tips: -# - each model must be listening on a unique address and port -# - the model name is in this format: "profile_name:model", like "coding:qwen" -# - the profile will load and unload all models in the profile at the same time -profiles: - coding: - - "llama" - - "qwen-unlisted" +# +# - models must be defined above in the Models section +# - a model can only be a member of one group +# - group behaviour is controlled via the `swap`, `exclusive` and `persistent` fields +# - see issue #109 for details +# +# NOTE: the example below uses model names that are not defined above for demonstration purposes +groups: + # group1 is the default behaviour of llama-swap where only one model is allowed + # to run a time across the whole llama-swap instance + "group1": + # swap controls the model swapping behaviour in within the group + # - true : only one model is allowed to run at a time + # - false: all models can run together, no swapping + swap: true + + # exclusive controls how the group affects other groups + # - true: causes all other groups to unload their models when this group runs a model + # - false: does not affect other groups + exclusive: true + + # members references the models defined above + members: + - "llama" + - "qwen-unlisted" + + # models in this group are never unloaded + "group2": + swap: false + exclusive: false + members: + - "docker-llama" + # (not defined above, here for example) + - "modelA" + - "modelB" + + "forever": + # setting persistent to true causes the group to never be affected by the swapping behaviour of + # other groups. It is a shortcut to keeping some models always loaded. + persistent: true + + # set swap/exclusive to false to prevent swapping inside the group and effect on other groups + swap: false + exclusive: false + members: + - "forever-modelA" + - "forever-modelB" + - "forever-modelc" ``` ### Use Case Examples diff --git a/llama-swap.go b/llama-swap.go index 7df4336..c570220 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -34,6 +34,10 @@ func main() { os.Exit(1) } + if len(config.Profiles) > 0 { + fmt.Println("WARNING: Profile functionality has been removed in favor of Groups. See the README for more information.") + } + if mode := os.Getenv("GIN_MODE"); mode != "" { gin.SetMode(mode) } else { diff --git a/proxy/config.go b/proxy/config.go index bef11d9..ee86174 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -3,12 +3,15 @@ package proxy import ( "fmt" "os" + "sort" "strings" "github.com/google/shlex" "gopkg.in/yaml.v3" ) +const DEFAULT_GROUP_ID = "(default)" + type ModelConfig struct { Cmd string `yaml:"cmd"` Proxy string `yaml:"proxy"` @@ -24,12 +27,38 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) { return SanitizeCommand(m.Cmd) } +type GroupConfig struct { + Swap bool `yaml:"swap"` + Exclusive bool `yaml:"exclusive"` + Persistent bool `yaml:"persistent"` + Members []string `yaml:"members"` +} + +// set default values for GroupConfig +func (c *GroupConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + type rawGroupConfig GroupConfig + defaults := rawGroupConfig{ + Swap: true, + Exclusive: true, + Persistent: false, + Members: []string{}, + } + + if err := unmarshal(&defaults); err != nil { + return err + } + + *c = GroupConfig(defaults) + return nil +} + type Config struct { HealthCheckTimeout int `yaml:"healthCheckTimeout"` LogRequests bool `yaml:"logRequests"` LogLevel string `yaml:"logLevel"` - Models map[string]ModelConfig `yaml:"models"` + Models map[string]ModelConfig `yaml:"models"` /* key is model ID */ Profiles map[string][]string `yaml:"profiles"` + Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */ // map aliases to actual model IDs aliases map[string]string @@ -53,16 +82,16 @@ func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) { } } -func LoadConfig(path string) (*Config, error) { +func LoadConfig(path string) (Config, error) { data, err := os.ReadFile(path) if err != nil { - return nil, err + return Config{}, err } var config Config err = yaml.Unmarshal(data, &config) if err != nil { - return nil, err + return Config{}, err } if config.HealthCheckTimeout < 15 { @@ -77,7 +106,72 @@ func LoadConfig(path string) (*Config, error) { } } - return &config, nil + config = AddDefaultGroupToConfig(config) + // check that members are all unique in the groups + memberUsage := make(map[string]string) // maps member to group it appears in + for groupID, groupConfig := range config.Groups { + prevSet := make(map[string]bool) + for _, member := range groupConfig.Members { + // Check for duplicates within this group + if _, found := prevSet[member]; found { + return Config{}, fmt.Errorf("duplicate model member %s found in group: %s", member, groupID) + } + prevSet[member] = true + + // Check if member is used in another group + if existingGroup, exists := memberUsage[member]; exists { + return Config{}, fmt.Errorf("model member %s is used in multiple groups: %s and %s", member, existingGroup, groupID) + } + memberUsage[member] = groupID + } + } + + return config, nil +} + +// rewrites the yaml to include a default group with any orphaned models +func AddDefaultGroupToConfig(config Config) Config { + + if config.Groups == nil { + config.Groups = make(map[string]GroupConfig) + } + + defaultGroup := GroupConfig{ + Swap: true, + Exclusive: true, + Members: []string{}, + } + // if groups is empty, create a default group and put + // all models into it + if len(config.Groups) == 0 { + for modelName := range config.Models { + defaultGroup.Members = append(defaultGroup.Members, modelName) + } + } else { + // iterate over existing group members and add non-grouped models into the default group + for modelName, _ := range config.Models { + foundModel := false + found: + // search for the model in existing groups + for _, groupConfig := range config.Groups { + for _, member := range groupConfig.Members { + if member == modelName { + foundModel = true + break found + } + } + } + + if !foundModel { + defaultGroup.Members = append(defaultGroup.Members, modelName) + } + } + } + + sort.Strings(defaultGroup.Members) // make consistent ordering for testing + config.Groups[DEFAULT_GROUP_ID] = defaultGroup + + return config } func SanitizeCommand(cmdStr string) ([]string, error) { diff --git a/proxy/config_test.go b/proxy/config_test.go index da2eb39..292c87c 100644 --- a/proxy/config_test.go +++ b/proxy/config_test.go @@ -35,11 +35,31 @@ models: aliases: - "m2" checkEndpoint: "/" + model3: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8081" + aliases: + - "mthree" + checkEndpoint: "/" + model4: + cmd: path/to/cmd --arg1 one + checkEndpoint: "/" + healthCheckTimeout: 15 profiles: test: - model1 - model2 +groups: + group1: + swap: true + exclusive: false + members: ["model2"] + forever: + exclusive: false + persistent: true + members: + - "model4" ` if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { @@ -52,7 +72,7 @@ profiles: t.Fatalf("Failed to load config: %v", err) } - expected := &Config{ + expected := Config{ Models: map[string]ModelConfig{ "model1": { Cmd: "path/to/cmd --arg1 one", @@ -68,6 +88,17 @@ profiles: Env: nil, CheckEndpoint: "/", }, + "model3": { + Cmd: "path/to/cmd --arg1 one", + Proxy: "http://localhost:8081", + Aliases: []string{"mthree"}, + Env: nil, + CheckEndpoint: "/", + }, + "model4": { + Cmd: "path/to/cmd --arg1 one", + CheckEndpoint: "/", + }, }, HealthCheckTimeout: 15, Profiles: map[string][]string{ @@ -77,6 +108,25 @@ profiles: "m1": "model1", "model-one": "model1", "m2": "model2", + "mthree": "model3", + }, + Groups: map[string]GroupConfig{ + DEFAULT_GROUP_ID: { + Swap: true, + Exclusive: true, + Members: []string{"model1", "model3"}, + }, + "group1": { + Swap: true, + Exclusive: false, + Members: []string{"model2"}, + }, + "forever": { + Swap: true, + Exclusive: false, + Persistent: true, + Members: []string{"model4"}, + }, }, } @@ -87,6 +137,51 @@ profiles: assert.Equal(t, "model1", realname) } +func TestConfig_GroupMemberIsUnique(t *testing.T) { + // Create a temporary YAML file for testing + tempDir, err := os.MkdirTemp("", "test-config") + if err != nil { + t.Fatalf("Failed to create temporary directory: %v", err) + } + defer os.RemoveAll(tempDir) + + tempFile := filepath.Join(tempDir, "config.yaml") + content := ` +models: + model1: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8080" + model2: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8081" + checkEndpoint: "/" + model3: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8081" + checkEndpoint: "/" + +healthCheckTimeout: 15 +groups: + group1: + swap: true + exclusive: false + members: ["model2"] + group2: + swap: true + exclusive: false + members: ["model2"] +` + + if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write temporary file: %v", err) + } + + // Load the config and verify + _, err = LoadConfig(tempFile) + assert.NotNil(t, err) + +} + func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { config := &ModelConfig{ Cmd: `python model1.py \ diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go index d6982c5..0082192 100644 --- a/proxy/helpers_test.go +++ b/proxy/helpers_test.go @@ -14,6 +14,7 @@ import ( var ( nextTestPort int = 12000 portMutex sync.Mutex + testLogger = NewLogMonitorWriter(os.Stdout) ) // Check if the binary exists @@ -26,6 +27,17 @@ func TestMain(m *testing.M) { gin.SetMode(gin.TestMode) + switch os.Getenv("LOG_LEVEL") { + case "debug": + testLogger.SetLogLevel(LevelDebug) + case "warn": + testLogger.SetLogLevel(LevelWarn) + case "info": + testLogger.SetLogLevel(LevelInfo) + default: + testLogger.SetLogLevel(LevelWarn) + } + m.Run() } diff --git a/proxy/process.go b/proxy/process.go index f9d17cc..0969d1c 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -301,6 +301,10 @@ func (p *Process) start() error { } func (p *Process) Stop() { + if !isValidTransition(p.CurrentState(), StateStopping) { + return + } + // wait for any inflight requests before proceeding p.inFlightRequests.Wait() p.proxyLogger.Debugf("<%s> Stopping process", p.ID) diff --git a/proxy/processgroup.go b/proxy/processgroup.go new file mode 100644 index 0000000..9c244a4 --- /dev/null +++ b/proxy/processgroup.go @@ -0,0 +1,111 @@ +package proxy + +import ( + "fmt" + "net/http" + "slices" + "sync" +) + +type ProcessGroup struct { + sync.Mutex + + config Config + id string + swap bool + exclusive bool + persistent bool + + proxyLogger *LogMonitor + upstreamLogger *LogMonitor + + // map of current processes + processes map[string]*Process + lastUsedProcess string +} + +func NewProcessGroup(id string, config Config, proxyLogger *LogMonitor, upstreamLogger *LogMonitor) *ProcessGroup { + groupConfig, ok := config.Groups[id] + if !ok { + panic("Unable to find configuration for group id: " + id) + } + + pg := &ProcessGroup{ + id: id, + config: config, + swap: groupConfig.Swap, + exclusive: groupConfig.Exclusive, + persistent: groupConfig.Persistent, + proxyLogger: proxyLogger, + upstreamLogger: upstreamLogger, + processes: make(map[string]*Process), + } + + // Create a Process for each member in the group + for _, modelID := range groupConfig.Members { + modelConfig, modelID, _ := pg.config.FindConfig(modelID) + process := NewProcess(modelID, pg.config.HealthCheckTimeout, modelConfig, pg.upstreamLogger, pg.proxyLogger) + pg.processes[modelID] = process + } + + return pg +} + +// ProxyRequest proxies a request to the specified model +func (pg *ProcessGroup) ProxyRequest(modelID string, writer http.ResponseWriter, request *http.Request) error { + if !pg.HasMember(modelID) { + return fmt.Errorf("model %s not part of group %s", modelID, pg.id) + } + + pg.Lock() + if pg.swap && pg.lastUsedProcess != modelID { + if pg.lastUsedProcess != "" { + pg.processes[pg.lastUsedProcess].Stop() + } + pg.lastUsedProcess = modelID + } + pg.Unlock() + + pg.processes[modelID].ProxyRequest(writer, request) + return nil +} + +func (pg *ProcessGroup) HasMember(modelName string) bool { + return slices.Contains(pg.config.Groups[pg.id].Members, modelName) +} + +func (pg *ProcessGroup) StopProcesses() { + pg.Lock() + defer pg.Unlock() + pg.stopProcesses() +} + +// stopProcesses stops all processes in the group +func (pg *ProcessGroup) stopProcesses() { + if len(pg.processes) == 0 { + return + } + + // stop Processes in parallel + var wg sync.WaitGroup + for _, process := range pg.processes { + wg.Add(1) + go func(process *Process) { + defer wg.Done() + process.Stop() + }(process) + } + wg.Wait() +} + +func (pg *ProcessGroup) Shutdown() { + var wg sync.WaitGroup + for _, process := range pg.processes { + wg.Add(1) + go func(process *Process) { + defer wg.Done() + process.Shutdown() + }(process) + } + wg.Wait() +} diff --git a/proxy/processgroup_test.go b/proxy/processgroup_test.go new file mode 100644 index 0000000..c6d3670 --- /dev/null +++ b/proxy/processgroup_test.go @@ -0,0 +1,96 @@ +package proxy + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +var processGroupTestConfig = AddDefaultGroupToConfig(Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + "model3": getTestSimpleResponderConfig("model3"), + "model4": getTestSimpleResponderConfig("model4"), + "model5": getTestSimpleResponderConfig("model5"), + }, + Groups: map[string]GroupConfig{ + "G1": { + Swap: true, + Exclusive: true, + Members: []string{"model1", "model2"}, + }, + "G2": { + Swap: false, + Exclusive: true, + Members: []string{"model3", "model4"}, + }, + }, +}) + +func TestProcessGroup_DefaultHasCorrectModel(t *testing.T) { + pg := NewProcessGroup(DEFAULT_GROUP_ID, processGroupTestConfig, testLogger, testLogger) + assert.True(t, pg.HasMember("model5")) +} + +func TestProcessGroup_HasMember(t *testing.T) { + pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) + assert.True(t, pg.HasMember("model1")) + assert.True(t, pg.HasMember("model2")) + assert.False(t, pg.HasMember("model3")) +} + +func TestProcessGroup_ProxyRequestSwapIsTrue(t *testing.T) { + pg := NewProcessGroup("G1", processGroupTestConfig, testLogger, testLogger) + defer pg.StopProcesses() + + tests := []string{"model1", "model2"} + + for _, modelName := range tests { + t.Run(modelName, func(t *testing.T) { + reqBody := `{"x", "y"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + assert.NoError(t, pg.ProxyRequest(modelName, w, req)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), modelName) + + // make sure only one process is in the running state + count := 0 + for _, process := range pg.processes { + if process.CurrentState() == StateReady { + count++ + } + } + assert.Equal(t, 1, count) + }) + } +} + +func TestProcessGroup_ProxyRequestSwapIsFalse(t *testing.T) { + pg := NewProcessGroup("G2", processGroupTestConfig, testLogger, testLogger) + defer pg.StopProcesses() + + tests := []string{"model3", "model4"} + + for _, modelName := range tests { + t.Run(modelName, func(t *testing.T) { + reqBody := `{"x", "y"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + assert.NoError(t, pg.ProxyRequest(modelName, w, req)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), modelName) + }) + } + + // make sure all the processes are running + for _, process := range pg.processes { + assert.Equal(t, StateReady, process.CurrentState()) + } +} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index da7b1a9..0b243fd 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -26,17 +26,18 @@ const ( type ProxyManager struct { sync.Mutex - config *Config - currentProcesses map[string]*Process - ginEngine *gin.Engine + config Config + ginEngine *gin.Engine // logging proxyLogger *LogMonitor upstreamLogger *LogMonitor muxLogger *LogMonitor + + processGroups map[string]*ProcessGroup } -func New(config *Config) *ProxyManager { +func New(config Config) *ProxyManager { // set up loggers stdoutLogger := NewLogMonitorWriter(os.Stdout) upstreamLogger := NewLogMonitorWriter(stdoutLogger) @@ -65,13 +66,20 @@ func New(config *Config) *ProxyManager { } pm := &ProxyManager{ - config: config, - currentProcesses: make(map[string]*Process), - ginEngine: gin.New(), + config: config, + ginEngine: gin.New(), proxyLogger: proxyLogger, muxLogger: stdoutLogger, upstreamLogger: upstreamLogger, + + processGroups: make(map[string]*ProcessGroup), + } + + // create the process groups + for groupID := range config.Groups { + processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger) + pm.processGroups[groupID] = processGroup } pm.ginEngine.Use(func(c *gin.Context) { @@ -200,27 +208,17 @@ func (pm *ProxyManager) StopProcesses() { pm.Lock() defer pm.Unlock() - pm.stopProcesses() -} - -// for internal usage -func (pm *ProxyManager) stopProcesses() { - if len(pm.currentProcesses) == 0 { - return - } - // stop Processes in parallel var wg sync.WaitGroup - for _, process := range pm.currentProcesses { + for _, processGroup := range pm.processGroups { wg.Add(1) - go func(process *Process) { + go func(processGroup *ProcessGroup) { defer wg.Done() - process.Stop() - }(process) + processGroup.stopProcesses() + }(processGroup) } - wg.Wait() - pm.currentProcesses = make(map[string]*Process) + wg.Wait() } // Shutdown is called to shutdown all upstream processes @@ -229,18 +227,44 @@ func (pm *ProxyManager) Shutdown() { pm.Lock() defer pm.Unlock() - // shutdown process in parallel + pm.proxyLogger.Debug("Shutdown() called in proxy manager") + var wg sync.WaitGroup - for _, process := range pm.currentProcesses { + // Send shutdown signal to all process in groups + for _, processGroup := range pm.processGroups { wg.Add(1) - go func(process *Process) { + go func(processGroup *ProcessGroup) { defer wg.Done() - process.Shutdown() - }(process) + processGroup.Shutdown() + }(processGroup) } wg.Wait() } +func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) { + // de-alias the real model name and get a real one + realModelName, found := pm.config.RealModelName(requestedModel) + if !found { + return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel) + } + + processGroup := pm.findGroupByModelName(realModelName) + if processGroup == nil { + return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel) + } + + if processGroup.exclusive { + pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id) + for groupId, otherGroup := range pm.processGroups { + if groupId != processGroup.id && !otherGroup.persistent { + otherGroup.StopProcesses() + } + } + } + + return processGroup, realModelName, nil +} + func (pm *ProxyManager) listModelsHandler(c *gin.Context) { data := []interface{}{} for id, modelConfig := range pm.config.Models { @@ -270,79 +294,6 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) { } } -func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) { - pm.Lock() - defer pm.Unlock() - - // Check if requestedModel contains a PROFILE_SPLIT_CHAR - profileName, modelName := splitRequestedModel(requestedModel) - - if profileName != "" { - if _, found := pm.config.Profiles[profileName]; !found { - return nil, fmt.Errorf("model group not found %s", profileName) - } - } - - // de-alias the real model name and get a real one - realModelName, found := pm.config.RealModelName(modelName) - if !found { - return nil, fmt.Errorf("could not find modelID for %s", requestedModel) - } - - // check if model is part of the profile - if profileName != "" { - found := false - for _, item := range pm.config.Profiles[profileName] { - if item == realModelName { - found = true - break - } - } - - if !found { - return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName) - } - } - - // exit early when already running, otherwise stop everything and swap - requestedProcessKey := ProcessKeyName(profileName, realModelName) - - if process, found := pm.currentProcesses[requestedProcessKey]; found { - pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel) - return process, nil - } - - // stop all running models - pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel) - pm.stopProcesses() - if profileName == "" { - modelConfig, modelID, found := pm.config.FindConfig(realModelName) - if !found { - return nil, fmt.Errorf("could not find configuration for %s", realModelName) - } - - process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger) - processKey := ProcessKeyName(profileName, modelID) - pm.currentProcesses[processKey] = process - } else { - for _, modelName := range pm.config.Profiles[profileName] { - if realModelName, found := pm.config.RealModelName(modelName); found { - modelConfig, modelID, found := pm.config.FindConfig(realModelName) - if !found { - return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName) - } - - process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger) - processKey := ProcessKeyName(profileName, modelID) - pm.currentProcesses[processKey] = process - } - } - } - - // requestedProcessKey should exist due to swap - return pm.currentProcesses[requestedProcessKey], nil -} - func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { requestedModel := c.Param("model_id") @@ -351,13 +302,14 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { return } - if process, err := pm.swapModel(requestedModel); err != nil { - pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) - } else { - // rewrite the path - c.Request.URL.Path = c.Param("upstreamPath") - process.ProxyRequest(c.Writer, c.Request) + processGroup, _, err := pm.swapProcessGroup(requestedModel) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) } + + // rewrite the path + c.Request.URL.Path = c.Param("upstreamPath") + processGroup.ProxyRequest(requestedModel, c.Writer, c.Request) } func (pm *ProxyManager) upstreamIndex(c *gin.Context) { @@ -397,29 +349,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") } - process, err := pm.swapModel(requestedModel) - + processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) if err != nil { - pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) - return + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) } // issue #69 allow custom model names to be sent to upstream - if process.config.UseModelName != "" { - bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName) + useModelName := pm.config.Models[realModelName].UseModelName + if useModelName != "" { + bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error())) + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) return } - } else { - profileName, modelName := splitRequestedModel(requestedModel) - if profileName != "" { - bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName) - if err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error())) - return - } - } } c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -428,8 +370,10 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { c.Request.Header.Del("transfer-encoding") c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) - process.ProxyRequest(c.Writer, c.Request) - + if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) + } } func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { @@ -451,26 +395,24 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { return } - // Swap to the requested model - process, err := pm.swapModel(requestedModel) + processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) if err != nil { - pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) - return + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) } - // Get profile name and model name from the requested model - profileName, modelName := splitRequestedModel(requestedModel) - // Copy all form values for key, values := range c.Request.MultipartForm.Value { for _, value := range values { fieldValue := value // If this is the model field and we have a profile, use just the model name if key == "model" { - if process.config.UseModelName != "" { - fieldValue = process.config.UseModelName - } else if profileName != "" { - fieldValue = modelName + // # issue #69 allow custom model names to be sent to upstream + useModelName := pm.config.Models[realModelName].UseModelName + + if useModelName != "" { + fieldValue = useModelName + } else { + fieldValue = requestedModel } } field, err := multipartWriter.CreateFormField(key) @@ -532,7 +474,10 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) // Use the modified request for proxying - process.ProxyRequest(c.Writer, modifiedReq) + if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) + } } func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { @@ -554,14 +499,15 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) { context.Header("Content-Type", "application/json") runningProcesses := make([]gin.H, 0) // Default to an empty response. - for _, process := range pm.currentProcesses { - - // Append the process ID and State (multiple entries if profiles are being used). - runningProcesses = append(runningProcesses, gin.H{ - "model": process.ID, - "state": process.state, - }) - + for _, processGroup := range pm.processGroups { + for _, process := range processGroup.processes { + if process.CurrentState() == StateReady { + runningProcesses = append(runningProcesses, gin.H{ + "model": process.ID, + "state": process.state, + }) + } + } } // Put the results under the `running` key. @@ -572,15 +518,11 @@ func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) { context.JSON(http.StatusOK, response) // Always return 200 OK } -func ProcessKeyName(groupName, modelName string) string { - return groupName + PROFILE_SPLIT_CHAR + modelName -} - -func splitRequestedModel(requestedModel string) (string, string) { - profileName, modelName := "", requestedModel - if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 { - profileName = requestedModel[:idx] - modelName = requestedModel[idx+1:] +func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup { + for _, group := range pm.processGroups { + if group.HasMember(modelName) { + return group + } } - return profileName, modelName + return nil } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 2352f48..cddfa40 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -16,14 +16,14 @@ import ( ) func TestProxyManager_SwapProcessCorrectly(t *testing.T) { - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), }, LogLevel: "error", - } + }) proxy := New(config) defer proxy.StopProcesses() @@ -36,59 +36,91 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { proxy.HandlerFunc(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), modelName) - - _, exists := proxy.currentProcesses[ProcessKeyName("", modelName)] - assert.True(t, exists, "expected %s key in currentProcesses", modelName) - } - - // make sure there's only one loaded model - assert.Len(t, proxy.currentProcesses, 1) } func TestProxyManager_SwapMultiProcess(t *testing.T) { - - model1 := "path1/model1" - model2 := "path2/model2" - - profileModel1 := ProcessKeyName("test", model1) - profileModel2 := ProcessKeyName("test", model2) - - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ - model1: getTestSimpleResponderConfig("model1"), - model2: getTestSimpleResponderConfig("model2"), - }, - Profiles: map[string][]string{ - "test": {model1, model2}, + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), }, LogLevel: "error", - } + Groups: map[string]GroupConfig{ + "G1": { + Swap: true, + Exclusive: false, + Members: []string{"model1"}, + }, + "G2": { + Swap: true, + Exclusive: false, + Members: []string{"model2"}, + }, + }, + }) proxy := New(config) defer proxy.StopProcesses() - for modelID, requestedModel := range map[string]string{ - "model1": profileModel1, - "model2": profileModel2, - } { + tests := []string{"model1", "model2"} + for _, requestedModel := range tests { + t.Run(requestedModel, func(t *testing.T) { + reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), requestedModel) + + }) + } + + // make sure there's two loaded models + assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) + assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) +} + +// Test that a persistent group is not affected by the swapping behaviour of +// other groups. +func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { + config := AddDefaultGroupToConfig(Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), // goes into the default group + "model2": getTestSimpleResponderConfig("model2"), + }, + LogLevel: "error", + Groups: map[string]GroupConfig{ + // the forever group is persistent and should not be affected by model1 + "forever": { + Swap: true, + Exclusive: false, + Persistent: true, + Members: []string{"model2"}, + }, + }, + }) + + proxy := New(config) + defer proxy.StopProcesses() + + // make requests to load all models, loading model1 should not affect model2 + tests := []string{"model2", "model1"} + for _, requestedModel := range tests { reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() proxy.HandlerFunc(w, req) assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), modelID) + assert.Contains(t, w.Body.String(), requestedModel) } - // make sure there's two loaded models - assert.Len(t, proxy.currentProcesses, 2) - _, exists := proxy.currentProcesses[profileModel1] - assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses") - - _, exists = proxy.currentProcesses[profileModel2] - assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses") + assert.Equal(t, proxy.findGroupByModelName("model2").processes["model2"].CurrentState(), StateReady) + assert.Equal(t, proxy.findGroupByModelName("model1").processes["model1"].CurrentState(), StateReady) } // When a request for a different model comes in ProxyManager should wait until @@ -98,7 +130,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { t.Skip("skipping slow test") } - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), @@ -106,7 +138,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { "model3": getTestSimpleResponderConfig("model3"), }, LogLevel: "error", - } + }) proxy := New(config) defer proxy.StopProcesses() @@ -149,7 +181,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { } func TestProxyManager_ListModelsHandler(t *testing.T) { - config := &Config{ + config := Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), @@ -217,51 +249,6 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { assert.Empty(t, expectedModels, "not all expected models were returned") } -func TestProxyManager_ProfileNonMember(t *testing.T) { - - model1 := "path1/model1" - model2 := "path2/model2" - - profileMemberName := ProcessKeyName("test", model1) - profileNonMemberName := ProcessKeyName("test", model2) - - config := &Config{ - HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ - model1: getTestSimpleResponderConfig("model1"), - model2: getTestSimpleResponderConfig("model2"), - }, - Profiles: map[string][]string{ - "test": {model1}, - }, - LogLevel: "error", - } - - proxy := New(config) - defer proxy.StopProcesses() - - // actual member of profile - { - reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() - - proxy.HandlerFunc(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "model1") - } - - // actual model, but non-member will 404 - { - reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() - - proxy.HandlerFunc(w, req) - assert.Equal(t, http.StatusNotFound, w.Code) - } -} - func TestProxyManager_Shutdown(t *testing.T) { // make broken model configurations model1Config := getTestSimpleResponderConfigPort("model1", 9991) @@ -273,24 +260,27 @@ func TestProxyManager_Shutdown(t *testing.T) { model3Config := getTestSimpleResponderConfigPort("model3", 9993) model3Config.Proxy = "http://localhost:10003/" - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, - Profiles: map[string][]string{ - "test": {"model1", "model2", "model3"}, - }, Models: map[string]ModelConfig{ "model1": model1Config, "model2": model2Config, "model3": model3Config, }, LogLevel: "error", - } + Groups: map[string]GroupConfig{ + "test": { + Swap: false, + Members: []string{"model1", "model2", "model3"}, + }, + }, + }) proxy := New(config) // Start all the processes var wg sync.WaitGroup - for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} { + for _, modelName := range []string{"model1", "model2", "model3"} { wg.Add(1) go func(modelName string) { defer wg.Done() @@ -298,11 +288,10 @@ func TestProxyManager_Shutdown(t *testing.T) { req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - // send a request to trigger the proxy to load + // send a request to trigger the proxy to load ... this should hang waiting for start up proxy.HandlerFunc(w, req) assert.Equal(t, http.StatusBadGateway, w.Code) assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") - //fmt.Println(w.Code, w.Body.String()) }(modelName) } @@ -314,67 +303,44 @@ func TestProxyManager_Shutdown(t *testing.T) { } func TestProxyManager_Unload(t *testing.T) { - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", - } + }) proxy := New(config) - proc, err := proxy.swapModel("model1") - assert.NoError(t, err) - assert.NotNil(t, proc) - - assert.Len(t, proxy.currentProcesses, 1) - req := httptest.NewRequest("GET", "/unload", nil) + reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() proxy.HandlerFunc(w, req) + + assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) + req = httptest.NewRequest("GET", "/unload", nil) + w = httptest.NewRecorder() + proxy.HandlerFunc(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, w.Body.String(), "OK") - assert.Len(t, proxy.currentProcesses, 0) -} -// issue 62, strip profile slug from model name -func TestProxyManager_StripProfileSlug(t *testing.T) { - config := &Config{ - HealthCheckTimeout: 15, - Profiles: map[string][]string{ - "test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go - }, - Models: map[string]ModelConfig{ - "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), - }, - LogLevel: "error", - } - - proxy := New(config) - defer proxy.StopProcesses() - - reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel") - req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "ok") + // give it a bit of time to stop + <-time.After(time.Millisecond * 250) + assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) } // Test issue #61 `Listing the current list of models and the loaded model.` func TestProxyManager_RunningEndpoint(t *testing.T) { // Shared configuration - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), }, - Profiles: map[string][]string{ - "test": {"model1", "model2"}, - }, - LogLevel: "error", - } + LogLevel: "debug", + }) // Define a helper struct to parse the JSON response. type RunningResponse struct { @@ -429,238 +395,126 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { // Is the model loaded? assert.Equal(t, "ready", response.Running[0].State) }) - - t.Run("multiple models via profile", func(t *testing.T) { - // Load more than one model. - for _, model := range []string{"model1", "model2"} { - profileModel := ProcessKeyName("test", model) - reqBody := fmt.Sprintf(`{"model":"%s"}`, profileModel) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) - assert.Equal(t, http.StatusOK, w.Code) - } - - // Simulate the browser call. - req := httptest.NewRequest("GET", "/running", nil) - w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) - - var response RunningResponse - - // The JSON response must be valid. - assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) - - // The response should contain 2 models. - assert.Len(t, response.Running, 2) - - expectedModels := map[string]struct{}{ - "model1": {}, - "model2": {}, - } - - // Iterate through the models and check their states as well. - for _, entry := range response.Running { - _, exists := expectedModels[entry.Model] - assert.True(t, exists, "unexpected model %s", entry.Model) - assert.Equal(t, "ready", entry.State) - delete(expectedModels, entry.Model) - } - - // Since we deleted each model while testing for its validity we should have no more models in the response. - assert.Empty(t, expectedModels, "unexpected additional models in response") - }) } func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, - Profiles: map[string][]string{ - "test": {"TheExpectedModel"}, - }, Models: map[string]ModelConfig{ "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), }, LogLevel: "error", - } + }) proxy := New(config) defer proxy.StopProcesses() - testCases := []struct { - name string - modelInput string - expectModel string - }{ - { - name: "With Profile Prefix", - modelInput: "test:TheExpectedModel", - expectModel: "TheExpectedModel", // Profile prefix should be stripped - }, - { - name: "Without Profile Prefix", - modelInput: "TheExpectedModel", - expectModel: "TheExpectedModel", // Should remain the same - }, - } + // Create a buffer with multipart form data + var b bytes.Buffer + w := multipart.NewWriter(&b) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - // Create a buffer with multipart form data - var b bytes.Buffer - w := multipart.NewWriter(&b) + // Add the model field + fw, err := w.CreateFormField("model") + assert.NoError(t, err) + _, err = fw.Write([]byte("TheExpectedModel")) + assert.NoError(t, err) - // Add the model field - fw, err := w.CreateFormField("model") - assert.NoError(t, err) - _, err = fw.Write([]byte(tc.modelInput)) - assert.NoError(t, err) + // Add a file field + fw, err = w.CreateFormFile("file", "test.mp3") + assert.NoError(t, err) + // Generate random content length between 10 and 20 + contentLength := rand.Intn(11) + 10 // 10 to 20 + content := make([]byte, contentLength) + _, err = fw.Write(content) + assert.NoError(t, err) + w.Close() - // Add a file field - fw, err = w.CreateFormFile("file", "test.mp3") - assert.NoError(t, err) - // Generate random content length between 10 and 20 - contentLength := rand.Intn(11) + 10 // 10 to 20 - content := make([]byte, contentLength) - _, err = fw.Write(content) - assert.NoError(t, err) - w.Close() + // Create the request with the multipart form data + req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) + req.Header.Set("Content-Type", w.FormDataContentType()) + rec := httptest.NewRecorder() + proxy.HandlerFunc(rec, req) - // Create the request with the multipart form data - req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) - req.Header.Set("Content-Type", w.FormDataContentType()) - rec := httptest.NewRecorder() - proxy.HandlerFunc(rec, req) - - // Verify the response - assert.Equal(t, http.StatusOK, rec.Code) - var response map[string]string - err = json.Unmarshal(rec.Body.Bytes(), &response) - assert.NoError(t, err) - assert.Equal(t, tc.expectModel, response["model"]) - assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder - }) - } -} - -func TestProxyManager_SplitRequestedModel(t *testing.T) { - - tests := []struct { - name string - requestedModel string - expectedProfile string - expectedModel string - }{ - {"no profile", "gpt-4", "", "gpt-4"}, - {"with profile", "profile1:gpt-4", "profile1", "gpt-4"}, - {"only profile", "profile1:", "profile1", ""}, - {"empty model", ":gpt-4", "", "gpt-4"}, - {"empty profile", ":", "", ""}, - {"no split char", "gpt-4", "", "gpt-4"}, - {"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - profileName, modelName := splitRequestedModel(tt.requestedModel) - if profileName != tt.expectedProfile { - t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel) - } - if modelName != tt.expectedModel { - t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel) - } - }) - } + // Verify the response + assert.Equal(t, http.StatusOK, rec.Code) + var response map[string]string + err = json.Unmarshal(rec.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "TheExpectedModel", response["model"]) + assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder } // Test useModelName in configuration sends overrides what is sent to upstream func TestProxyManager_UseModelName(t *testing.T) { - upstreamModelName := "upstreamModel" modelConfig := getTestSimpleResponderConfig(upstreamModelName) modelConfig.UseModelName = upstreamModelName - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, - Profiles: map[string][]string{ - "test": {"model1"}, - }, - Models: map[string]ModelConfig{ "model1": modelConfig, }, - LogLevel: "error", - } + }) proxy := New(config) defer proxy.StopProcesses() - tests := []struct { - description string - requestedModel string - }{ - {"useModelName over rides requested model", "model1"}, - {"useModelName over rides requested profile:model", "test:model1"}, - } + requestedModel := "model1" - for _, tt := range tests { - t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) { - reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel) - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := httptest.NewRecorder() + t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) { + reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), upstreamModelName) + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), upstreamModelName) + }) - }) - } + t.Run("useModelName over rides requested model: /v1/audio/transcriptions", func(t *testing.T) { + // Create a buffer with multipart form data + var b bytes.Buffer + w := multipart.NewWriter(&b) - for _, tt := range tests { - t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) { - // Create a buffer with multipart form data - var b bytes.Buffer - w := multipart.NewWriter(&b) + // Add the model field + fw, err := w.CreateFormField("model") + assert.NoError(t, err) + _, err = fw.Write([]byte(requestedModel)) + assert.NoError(t, err) - // Add the model field - fw, err := w.CreateFormField("model") - assert.NoError(t, err) - _, err = fw.Write([]byte(tt.requestedModel)) - assert.NoError(t, err) + // Add a file field + fw, err = w.CreateFormFile("file", "test.mp3") + assert.NoError(t, err) + _, err = fw.Write([]byte("test")) + assert.NoError(t, err) + w.Close() - // Add a file field - fw, err = w.CreateFormFile("file", "test.mp3") - assert.NoError(t, err) - _, err = fw.Write([]byte("test")) - assert.NoError(t, err) - w.Close() + // Create the request with the multipart form data + req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) + req.Header.Set("Content-Type", w.FormDataContentType()) + rec := httptest.NewRecorder() + proxy.HandlerFunc(rec, req) - // Create the request with the multipart form data - req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) - req.Header.Set("Content-Type", w.FormDataContentType()) - rec := httptest.NewRecorder() - proxy.HandlerFunc(rec, req) - - // Verify the response - assert.Equal(t, http.StatusOK, rec.Code) - var response map[string]string - err = json.Unmarshal(rec.Body.Bytes(), &response) - assert.NoError(t, err) - assert.Equal(t, upstreamModelName, response["model"]) - }) - } + // Verify the response + assert.Equal(t, http.StatusOK, rec.Code) + var response map[string]string + err = json.Unmarshal(rec.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, upstreamModelName, response["model"]) + }) } func TestProxyManager_CORSOptionsHandler(t *testing.T) { - config := &Config{ + config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, LogLevel: "error", - } + }) tests := []struct { name string @@ -720,3 +574,21 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { }) } } + +func TestProxyManager_Upstream(t *testing.T) { + config := AddDefaultGroupToConfig(Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + }, + LogLevel: "error", + }) + + proxy := New(config) + defer proxy.StopProcesses() + req := httptest.NewRequest("GET", "/upstream/model1/test", nil) + rec := httptest.NewRecorder() + proxy.HandlerFunc(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "model1", rec.Body.String()) +}