diff --git a/README.md b/README.md index 4b232d8..376dce3 100644 --- a/README.md +++ b/README.md @@ -2,17 +2,22 @@ ![llama-swap header image](header.jpeg) -[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models on demand. So let's swap the server on demand instead! +llama-swap is a golang server that automatically swaps the llama.cpp server on demand. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead! -llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically. +Features: -- ✅ easy to deploy: single binary with no dependencies -- ✅ full control over llama-server's startup settings -- ✅ ❤️ for users who are rely on llama.cpp for LLM inference +- ✅ Easy to deploy: single binary with no dependencies +- ✅ Single yaml configuration file +- ✅ Automatically switching between models +- ✅ Full control over llama.cpp server settings per model +- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`) +- ✅ Multiple GPU support +- ✅ Run multiple models at once with `profiles` +- ✅ Remote log monitoring at `/log` ## config.yaml -llama-swap's configuration purposefully simple. +llama-swap's configuration is purposefully simple. 
```yaml # Seconds to wait for llama.cpp to load and be ready to serve requests @@ -24,25 +29,24 @@ models: "llama": cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf - # where to reach the server started by cmd + # where to reach the server started by cmd, make sure the ports match proxy: http://127.0.0.1:8999 - # aliases model names to use this configuration for + # aliases names to use this model for aliases: - "gpt-4o-mini" - "gpt-3.5-turbo" - # wait for this path to return an HTTP 200 before serving requests - # defaults to /health to match llama.cpp - # - # use "none" to skip endpoint checking. This may cause requests to fail - # until the server is ready + # check this path for an HTTP 200 OK before serving requests + # default: /health to match llama.cpp + # use "none" to skip endpoint checking, but may cause HTTP errors + # until the model is ready checkEndpoint: /custom-endpoint - # automatically unload the model after 10 seconds + # automatically unload the model after this many seconds # ttl values must be a value greater than 0 # default: 0 = never unload model - ttl: 5 + ttl: 60 "qwen": # environment variables to pass to the command @@ -53,8 +57,18 @@ models: cmd: > llama-server --port 8999 --model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf - proxy: http://127.0.0.1:8999 + +# profiles make it easy to manage multi-model (and GPU) configurations. 
+# +# Tips: +# - each model must be listening on a unique address and port +# - the model name is in this format: "profile_name/model", like "coding/qwen" +# - the profile will load and unload all models in the profile at the same time +profiles: + coding: + - "qwen" + - "llama" ``` ## Installation diff --git a/config.example.yaml b/config.example.yaml index e660b11..583b4b2 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -6,9 +6,9 @@ models: "llama": cmd: > models/llama-server-osx - --port 8999 + --port 9001 -m models/Llama-3.2-1B-Instruct-Q4_0.gguf - proxy: http://127.0.0.1:8999 + proxy: http://127.0.0.1:9001 # list of model name aliases this llama.cpp instance can serve aliases: @@ -21,8 +21,8 @@ models: ttl: 5 "qwen": - cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf - proxy: http://127.0.0.1:8999 + cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf + proxy: http://127.0.0.1:9002 aliases: - gpt-3.5-turbo @@ -44,4 +44,10 @@ models: proxy: http://127.0.0.1:8999 "broken_timeout": cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf - proxy: http://127.0.0.1:9000 \ No newline at end of file + proxy: http://127.0.0.1:9000 + +# creating a coding profile with models for code generation and general questions +profiles: + coding: + - "qwen" + - "llama" \ No newline at end of file diff --git a/misc/simple-responder/simple-responder.go b/misc/simple-responder/simple-responder.go index 61fe3f0..6c9f235 100644 --- a/misc/simple-responder/simple-responder.go +++ b/misc/simple-responder/simple-responder.go @@ -16,12 +16,16 @@ func main() { flag.Parse() // Parse the command-line flags - // Set up the handler function using the provided response message - http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + responseMessageHandler := func(w http.ResponseWriter, r *http.Request) { // Set the header to text/plain w.Header().Set("Content-Type", 
"text/plain") fmt.Fprintln(w, *responseMessage) - }) + } + + // Set up the handler function using the provided response message + http.HandleFunc("/v1/chat/completions", responseMessageHandler) + http.HandleFunc("/v1/completions", responseMessageHandler) + http.HandleFunc("/test", responseMessageHandler) http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") @@ -43,6 +47,11 @@ func main() { w.Write([]byte(response)) }) + http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path) + }) + address := "127.0.0.1:" + *port // Address with the specified port fmt.Printf("Server is listening on port %s\n", *port) diff --git a/proxy/config.go b/proxy/config.go index ae81150..ad0794f 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -22,26 +22,30 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) { } type Config struct { - Models map[string]ModelConfig `yaml:"models"` HealthCheckTimeout int `yaml:"healthCheckTimeout"` + Models map[string]ModelConfig `yaml:"models"` + Profiles map[string][]string `yaml:"profiles"` + + // map aliases to actual model IDs + aliases map[string]string +} + +func (c *Config) RealModelName(search string) (string, bool) { + if _, found := c.Models[search]; found { + return search, true + } else if name, found := c.aliases[search]; found { + return name, found + } else { + return "", false + } } func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) { - modelConfig, found := c.Models[modelName] - if found { - return modelConfig, modelName, true + if realName, found := c.RealModelName(modelName); !found { + return ModelConfig{}, "", false + } else { + return c.Models[realName], realName, true } - - // Search through aliases to find the right config - for actual, config := range c.Models { - for _, alias := range config.Aliases { - if alias == modelName { - 
return config, actual, true - } - } - } - - return ModelConfig{}, "", false } func LoadConfig(path string) (*Config, error) { @@ -60,6 +64,14 @@ func LoadConfig(path string) (*Config, error) { config.HealthCheckTimeout = 15 } + // Populate the aliases map + config.aliases = make(map[string]string) + for modelName, modelConfig := range config.Models { + for _, alias := range modelConfig.Aliases { + config.aliases[alias] = modelName + } + } + return &config, nil } diff --git a/proxy/config_test.go b/proxy/config_test.go index b45df37..871e7e0 100644 --- a/proxy/config_test.go +++ b/proxy/config_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" ) -func TestLoadConfig(t *testing.T) { +func TestConfig_Load(t *testing.T) { // Create a temporary YAML file for testing tempDir, err := os.MkdirTemp("", "test-config") if err != nil { @@ -17,7 +17,8 @@ func TestLoadConfig(t *testing.T) { defer os.RemoveAll(tempDir) tempFile := filepath.Join(tempDir, "config.yaml") - content := `models: + content := ` +models: model1: cmd: path/to/cmd --arg1 one proxy: "http://localhost:8080" @@ -28,7 +29,17 @@ func TestLoadConfig(t *testing.T) { - "VAR1=value1" - "VAR2=value2" checkEndpoint: "/health" + model2: + cmd: path/to/cmd --arg1 one + proxy: "http://localhost:8081" + aliases: + - "m2" + checkEndpoint: "/" healthCheckTimeout: 15 +profiles: + test: + - model1 + - model2 ` if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { @@ -50,14 +61,33 @@ healthCheckTimeout: 15 Env: []string{"VAR1=value1", "VAR2=value2"}, CheckEndpoint: "/health", }, + "model2": { + Cmd: "path/to/cmd --arg1 one", + Proxy: "http://localhost:8081", + Aliases: []string{"m2"}, + Env: nil, + CheckEndpoint: "/", + }, }, HealthCheckTimeout: 15, + Profiles: map[string][]string{ + "test": {"model1", "model2"}, + }, + aliases: map[string]string{ + "m1": "model1", + "model-one": "model1", + "m2": "model2", + }, } assert.Equal(t, expected, config) + + realname, found := 
config.RealModelName("m1") + assert.True(t, found) + assert.Equal(t, "model1", realname) } -func TestModelConfigSanitizedCommand(t *testing.T) { +func TestConfig_ModelConfigSanitizedCommand(t *testing.T) { config := &ModelConfig{ Cmd: `python model1.py \ --arg1 value1 \ @@ -69,7 +99,10 @@ func TestModelConfigSanitizedCommand(t *testing.T) { assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args) } -func TestFindConfig(t *testing.T) { +func TestConfig_FindConfig(t *testing.T) { + + // TODO? + // make this shared between the different tests config := &Config{ Models: map[string]ModelConfig{ "model1": { @@ -88,6 +121,11 @@ func TestFindConfig(t *testing.T) { }, }, HealthCheckTimeout: 10, + aliases: map[string]string{ + "m1": "model1", + "model-one": "model1", + "m2": "model2", + }, } // Test finding a model by its name @@ -109,7 +147,7 @@ func TestFindConfig(t *testing.T) { assert.Equal(t, ModelConfig{}, modelConfig) } -func TestSanitizeCommand(t *testing.T) { +func TestConfig_SanitizeCommand(t *testing.T) { // Test a simple command args, err := SanitizeCommand("python model1.py") assert.NoError(t, err) diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go new file mode 100644 index 0000000..e8efa5b --- /dev/null +++ b/proxy/helpers_test.go @@ -0,0 +1,58 @@ +package proxy + +import ( + "fmt" + "os" + "path/filepath" + "runtime" + "sync" + "testing" + + "github.com/gin-gonic/gin" +) + +var ( + nextTestPort int = 12000 + portMutex sync.Mutex +) + +// Check if the binary exists +func TestMain(m *testing.M) { + binaryPath := getSimpleResponderPath() + if _, err := os.Stat(binaryPath); os.IsNotExist(err) { + fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath) + os.Exit(1) + } + + gin.SetMode(gin.TestMode) + + m.Run() +} + +// Helper function to get the binary path +func getSimpleResponderPath() string { + goos := runtime.GOOS + goarch := runtime.GOARCH + return filepath.Join("..", 
"build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) +} + +func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { + portMutex.Lock() + defer portMutex.Unlock() + + port := nextTestPort + nextTestPort++ + + return getTestSimpleResponderConfigPort(expectedMessage, port) +} + +func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig { + binaryPath := getSimpleResponderPath() + + // Create a process configuration + return ModelConfig{ + Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, port, expectedMessage), + Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), + CheckEndpoint: "/health", + } +} diff --git a/proxy/process.go b/proxy/process.go index f4d30c6..10212f3 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -113,7 +113,7 @@ func (p *Process) Stop() { p.Lock() defer p.Unlock() - if !p.isRunning { + if !p.isRunning || p.cmd == nil || p.cmd.Process == nil { return } diff --git a/proxy/process_test.go b/proxy/process_test.go index 34221c8..5b2481c 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -1,54 +1,15 @@ package proxy import ( - "fmt" "io" - "math/rand" "net/http" "net/http/httptest" - "os" - "path/filepath" - "runtime" "testing" "time" "github.com/stretchr/testify/assert" ) -// Check if the binary exists -func TestMain(m *testing.M) { - binaryPath := getBinaryPath() - if _, err := os.Stat(binaryPath); os.IsNotExist(err) { - fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath) - os.Exit(1) - } - m.Run() -} - -// Helper function to get the binary path -func getBinaryPath() string { - goos := runtime.GOOS - goarch := runtime.GOARCH - return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) -} - -func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { - // Define the range - min := 12000 - max := 13000 - - // Generate a random number between 12000 and 13000 - randomPort := 
rand.Intn(max-min+1) + min - binaryPath := getBinaryPath() - - // Create a process configuration - return ModelConfig{ - Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage), - Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort), - CheckEndpoint: "/health", - } -} - func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { logMonitor := NewLogMonitorWriter(io.Discard) expectedMessage := "testing91931" @@ -56,7 +17,9 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { // Create a process process := NewProcess("test-process", 5, config, logMonitor) - req := httptest.NewRequest("GET", "/", nil) + defer process.Stop() + + req := httptest.NewRequest("GET", "/test", nil) w := httptest.NewRecorder() // process is automatically started @@ -92,6 +55,8 @@ func TestProcess_BrokenModelConfig(t *testing.T) { } process := NewProcess("broken", 1, config, NewLogMonitor()) + defer process.Stop() + req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() process.ProxyRequest(w, req) @@ -99,6 +64,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) { assert.Contains(t, w.Body.String(), "unable to start process") } +// test that the process unloads after the TTL func TestProcess_UnloadAfterTTL(t *testing.T) { if testing.Short() { t.Skip("skipping long auto unload TTL test") @@ -111,7 +77,9 @@ func TestProcess_UnloadAfterTTL(t *testing.T) { assert.Equal(t, 3, config.UnloadAfter) process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard)) - req := httptest.NewRequest("GET", "/", nil) + defer process.Stop() + + req := httptest.NewRequest("GET", "/test", nil) w := httptest.NewRecorder() // Proxy the request (auto start) diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 601fa6e..3c3b874 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "strconv" + "strings" "sync" "time" @@ -16,18 +17,18 @@ import ( type ProxyManager struct { sync.Mutex - 
config *Config - currentProcess *Process - logMonitor *LogMonitor - ginEngine *gin.Engine + config *Config + currentProcesses map[string]*Process + logMonitor *LogMonitor + ginEngine *gin.Engine } func New(config *Config) *ProxyManager { pm := &ProxyManager{ - config: config, - currentProcess: nil, - logMonitor: NewLogMonitor(), - ginEngine: gin.New(), + config: config, + currentProcesses: make(map[string]*Process), + logMonitor: NewLogMonitor(), + ginEngine: gin.New(), } // Set up routes using the Gin engine @@ -43,7 +44,7 @@ func New(config *Config) *ProxyManager { pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) - pm.ginEngine.NoRoute(pm.proxyRequestHandler) + pm.ginEngine.NoRoute(pm.proxyNoRouteHandler) // Disable console color for testing gin.DisableConsoleColor() @@ -59,6 +60,21 @@ func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) { pm.ginEngine.ServeHTTP(w, r) } +func (pm *ProxyManager) StopProcesses() { + pm.Lock() + defer pm.Unlock() + + pm.stopProcesses() +} + +// for internal usage +func (pm *ProxyManager) stopProcesses() { + for _, process := range pm.currentProcesses { + process.Stop() + } + pm.currentProcesses = make(map[string]*Process) +} + func (pm *ProxyManager) listModelsHandler(c *gin.Context) { data := []interface{}{} for id := range pm.config.Models { @@ -80,27 +96,64 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) { } } -func (pm *ProxyManager) swapModel(requestedModel string) error { +func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) { pm.Lock() defer pm.Unlock() - // find the model configuration matching requestedModel - modelConfig, modelID, found := pm.config.FindConfig(requestedModel) - if !found { - return fmt.Errorf("could not find configuration for %s", requestedModel) + // Check if requestedModel contains a / + groupName, modelName := "", requestedModel + if idx := strings.Index(requestedModel, "/"); idx 
!= -1 { + groupName = requestedModel[:idx] + modelName = requestedModel[idx+1:] } - // do nothing as it's already the correct process - if pm.currentProcess != nil { - if pm.currentProcess.ID == modelID { - return nil - } else { - pm.currentProcess.Stop() + if groupName != "" { + if _, found := pm.config.Profiles[groupName]; !found { + return nil, fmt.Errorf("model group not found %s", groupName) } } - pm.currentProcess = NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) - return nil + // de-alias the real model name and get a real one + realModelName, found := pm.config.RealModelName(modelName) + if !found { + return nil, fmt.Errorf("could not find modelID for %s", requestedModel) + } + + // exit early when already running, otherwise stop everything and swap + requestedProcessKey := groupName + "/" + realModelName + if process, found := pm.currentProcesses[requestedProcessKey]; found { + return process, nil + } + + // stop all running models + pm.stopProcesses() + + if groupName == "" { + modelConfig, modelID, found := pm.config.FindConfig(realModelName) + if !found { + return nil, fmt.Errorf("could not find configuration for %s", realModelName) + } + + process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) + processKey := groupName + "/" + modelID + pm.currentProcesses[processKey] = process + } else { + for _, modelName := range pm.config.Profiles[groupName] { + if realModelName, found := pm.config.RealModelName(modelName); found { + modelConfig, modelID, found := pm.config.FindConfig(realModelName) + if !found { + return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName) + } + + process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) + processKey := groupName + "/" + modelID + pm.currentProcesses[processKey] = process + } + } + } + + // requestedProcessKey should exist due to swap + return 
pm.currentProcesses[requestedProcessKey], nil } func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) { @@ -120,24 +173,27 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) { return } - if err := pm.swapModel(model); err != nil { + if process, err := pm.swapModel(model); err != nil { c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error())) return + } else { + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + // dechunk it as we already have all the body bytes see issue #11 + c.Request.Header.Del("transfer-encoding") + c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) + + process.ProxyRequest(c.Writer, c.Request) + } + +} + +func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) { + // since maps are unordered, just use the first available process if one exists + for _, process := range pm.currentProcesses { + process.ProxyRequest(c.Writer, c.Request) + return } - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - - // dechunk it as we already have all the body bytes see issue #11 - c.Request.Header.Del("transfer-encoding") - c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) - - pm.currentProcess.ProxyRequest(c.Writer, c.Request) -} - -func (pm *ProxyManager) proxyRequestHandler(c *gin.Context) { - if pm.currentProcess != nil { - pm.currentProcess.ProxyRequest(c.Writer, c.Request) - } else { - c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request")) - } + c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request")) } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go new file mode 100644 index 0000000..2d15d7e --- /dev/null +++ b/proxy/proxymanager_test.go @@ -0,0 +1,76 @@ +package proxy + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestProxyManager_SwapProcessCorrectly(t *testing.T) { + 
config := &Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + }, + } + + proxy := New(config) + defer proxy.StopProcesses() + + for _, modelName := range []string{"model1", "model2"} { + reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), modelName) + + _, exists := proxy.currentProcesses["/"+modelName] + assert.True(t, exists, "expected %s key in currentProcesses", modelName) + + } + + // make sure there's only one loaded model + assert.Len(t, proxy.currentProcesses, 1) +} + +func TestProxyManager_SwapMultiProcess(t *testing.T) { + config := &Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + }, + Profiles: map[string][]string{ + "test": {"model1", "model2"}, + }, + } + + proxy := New(config) + defer proxy.StopProcesses() + + for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} { + reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), modelID) + } + + // make sure there's two loaded models + assert.Len(t, proxy.currentProcesses, 2) + _, exists := proxy.currentProcesses["test/model1"] + assert.True(t, exists, "expected test/model1 key in currentProcesses") + + _, exists = proxy.currentProcesses["test/model2"] + assert.True(t, exists, "expected test/model2 key in currentProcesses") + +}