Implement Multi-Process Handling (#7)

Refactor code to support starting multiple backend llama.cpp servers. This functionality is exposed as `profiles` to create a simple configuration format.

Changes: 

* refactor proxy tests to get ready for multi-process support
* update proxy/ProxyManager to support multiple processes (#7)
* Add support for Groups in configuration
* improve handling of Model alias configs
* implement multi-model swapping
* improve code clarity for swapModel
* improve docs, rename groups to profiles in config
This commit is contained in:
Benson Wong
2024-11-23 19:45:13 -08:00
committed by GitHub
parent 533162ce6a
commit 73ad85ea69
10 changed files with 361 additions and 124 deletions

View File

@@ -2,17 +2,22 @@
![llama-swap header image](header.jpeg) ![llama-swap header image](header.jpeg)
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models on demand. So let's swap the server on demand instead! llama-swap is a golang server that automatically swaps the llama.cpp server on demand. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically. Features:
-easy to deploy: single binary with no dependencies -Easy to deploy: single binary with no dependencies
-full control over llama-server's startup settings -Single yaml configuration file
-❤️ for users who are rely on llama.cpp for LLM inference -Automatically switching between models
- ✅ Full control over llama.cpp server settings per model
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
- ✅ Multiple GPU support
- ✅ Run multiple models at once with `profiles`
- ✅ Remote log monitoring at `/log`
## config.yaml ## config.yaml
llama-swap's configuration purposefully simple. llama-swap's configuration is purposefully simple.
```yaml ```yaml
# Seconds to wait for llama.cpp to load and be ready to serve requests # Seconds to wait for llama.cpp to load and be ready to serve requests
@@ -24,25 +29,24 @@ models:
"llama": "llama":
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
# where to reach the server started by cmd # where to reach the server started by cmd, make sure the ports match
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
# aliases model names to use this configuration for # aliases names to use this model for
aliases: aliases:
- "gpt-4o-mini" - "gpt-4o-mini"
- "gpt-3.5-turbo" - "gpt-3.5-turbo"
# wait for this path to return an HTTP 200 before serving requests # check this path for an HTTP 200 OK before serving requests
# defaults to /health to match llama.cpp # default: /health to match llama.cpp
# # use "none" to skip endpoint checking, but may cause HTTP errors
# use "none" to skip endpoint checking. This may cause requests to fail # until the model is ready
# until the server is ready
checkEndpoint: /custom-endpoint checkEndpoint: /custom-endpoint
# automatically unload the model after 10 seconds # automatically unload the model after this many seconds
# ttl values must be a value greater than 0 # ttl values must be a value greater than 0
# default: 0 = never unload model # default: 0 = never unload model
ttl: 5 ttl: 60
"qwen": "qwen":
# environment variables to pass to the command # environment variables to pass to the command
@@ -53,8 +57,18 @@ models:
cmd: > cmd: >
llama-server --port 8999 llama-server --port 8999
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf --model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
# profiles make it easy to manage multi-model (and GPU) configurations.
#
# Tips:
# - each model must be listening on a unique address and port
# - the model name is in this format: "profile_name/model", like "coding/qwen"
# - the profile will load and unload all models in the profile at the same time
profiles:
coding:
- "qwen"
- "llama"
``` ```
## Installation ## Installation

View File

@@ -6,9 +6,9 @@ models:
"llama": "llama":
cmd: > cmd: >
models/llama-server-osx models/llama-server-osx
--port 8999 --port 9001
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf -m models/Llama-3.2-1B-Instruct-Q4_0.gguf
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:9001
# list of model name aliases this llama.cpp instance can serve # list of model name aliases this llama.cpp instance can serve
aliases: aliases:
@@ -21,8 +21,8 @@ models:
ttl: 5 ttl: 5
"qwen": "qwen":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:9002
aliases: aliases:
- gpt-3.5-turbo - gpt-3.5-turbo
@@ -44,4 +44,10 @@ models:
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999
"broken_timeout": "broken_timeout":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:9000 proxy: http://127.0.0.1:9000
# creating a coding profile with models for code generation and general questions
profiles:
coding:
- "qwen"
- "llama"

View File

@@ -16,12 +16,16 @@ func main() {
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
// Set up the handler function using the provided response message responseMessageHandler := func(w http.ResponseWriter, r *http.Request) {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Set the header to text/plain // Set the header to text/plain
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, *responseMessage) fmt.Fprintln(w, *responseMessage)
}) }
// Set up the handler function using the provided response message
http.HandleFunc("/v1/chat/completions", responseMessageHandler)
http.HandleFunc("/v1/completions", responseMessageHandler)
http.HandleFunc("/test", responseMessageHandler)
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
@@ -43,6 +47,11 @@ func main() {
w.Write([]byte(response)) w.Write([]byte(response))
}) })
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
})
address := "127.0.0.1:" + *port // Address with the specified port address := "127.0.0.1:" + *port // Address with the specified port
fmt.Printf("Server is listening on port %s\n", *port) fmt.Printf("Server is listening on port %s\n", *port)

View File

@@ -22,26 +22,30 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
} }
type Config struct { type Config struct {
Models map[string]ModelConfig `yaml:"models"`
HealthCheckTimeout int `yaml:"healthCheckTimeout"` HealthCheckTimeout int `yaml:"healthCheckTimeout"`
Models map[string]ModelConfig `yaml:"models"`
Profiles map[string][]string `yaml:"profiles"`
// map aliases to actual model IDs
aliases map[string]string
}
func (c *Config) RealModelName(search string) (string, bool) {
if _, found := c.Models[search]; found {
return search, true
} else if name, found := c.aliases[search]; found {
return name, found
} else {
return "", false
}
} }
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) { func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
modelConfig, found := c.Models[modelName] if realName, found := c.RealModelName(modelName); !found {
if found { return ModelConfig{}, "", false
return modelConfig, modelName, true } else {
return c.Models[realName], realName, true
} }
// Search through aliases to find the right config
for actual, config := range c.Models {
for _, alias := range config.Aliases {
if alias == modelName {
return config, actual, true
}
}
}
return ModelConfig{}, "", false
} }
func LoadConfig(path string) (*Config, error) { func LoadConfig(path string) (*Config, error) {
@@ -60,6 +64,14 @@ func LoadConfig(path string) (*Config, error) {
config.HealthCheckTimeout = 15 config.HealthCheckTimeout = 15
} }
// Populate the aliases map
config.aliases = make(map[string]string)
for modelName, modelConfig := range config.Models {
for _, alias := range modelConfig.Aliases {
config.aliases[alias] = modelName
}
}
return &config, nil return &config, nil
} }

View File

@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestLoadConfig(t *testing.T) { func TestConfig_Load(t *testing.T) {
// Create a temporary YAML file for testing // Create a temporary YAML file for testing
tempDir, err := os.MkdirTemp("", "test-config") tempDir, err := os.MkdirTemp("", "test-config")
if err != nil { if err != nil {
@@ -17,7 +17,8 @@ func TestLoadConfig(t *testing.T) {
defer os.RemoveAll(tempDir) defer os.RemoveAll(tempDir)
tempFile := filepath.Join(tempDir, "config.yaml") tempFile := filepath.Join(tempDir, "config.yaml")
content := `models: content := `
models:
model1: model1:
cmd: path/to/cmd --arg1 one cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8080" proxy: "http://localhost:8080"
@@ -28,7 +29,17 @@ func TestLoadConfig(t *testing.T) {
- "VAR1=value1" - "VAR1=value1"
- "VAR2=value2" - "VAR2=value2"
checkEndpoint: "/health" checkEndpoint: "/health"
model2:
cmd: path/to/cmd --arg1 one
proxy: "http://localhost:8081"
aliases:
- "m2"
checkEndpoint: "/"
healthCheckTimeout: 15 healthCheckTimeout: 15
profiles:
test:
- model1
- model2
` `
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil { if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
@@ -50,14 +61,33 @@ healthCheckTimeout: 15
Env: []string{"VAR1=value1", "VAR2=value2"}, Env: []string{"VAR1=value1", "VAR2=value2"},
CheckEndpoint: "/health", CheckEndpoint: "/health",
}, },
"model2": {
Cmd: "path/to/cmd --arg1 one",
Proxy: "http://localhost:8081",
Aliases: []string{"m2"},
Env: nil,
CheckEndpoint: "/",
},
}, },
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
} }
assert.Equal(t, expected, config) assert.Equal(t, expected, config)
realname, found := config.RealModelName("m1")
assert.True(t, found)
assert.Equal(t, "model1", realname)
} }
func TestModelConfigSanitizedCommand(t *testing.T) { func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
config := &ModelConfig{ config := &ModelConfig{
Cmd: `python model1.py \ Cmd: `python model1.py \
--arg1 value1 \ --arg1 value1 \
@@ -69,7 +99,10 @@ func TestModelConfigSanitizedCommand(t *testing.T) {
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args) assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
} }
func TestFindConfig(t *testing.T) { func TestConfig_FindConfig(t *testing.T) {
// TODO?
// make this shared between the different tests
config := &Config{ config := &Config{
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": { "model1": {
@@ -88,6 +121,11 @@ func TestFindConfig(t *testing.T) {
}, },
}, },
HealthCheckTimeout: 10, HealthCheckTimeout: 10,
aliases: map[string]string{
"m1": "model1",
"model-one": "model1",
"m2": "model2",
},
} }
// Test finding a model by its name // Test finding a model by its name
@@ -109,7 +147,7 @@ func TestFindConfig(t *testing.T) {
assert.Equal(t, ModelConfig{}, modelConfig) assert.Equal(t, ModelConfig{}, modelConfig)
} }
func TestSanitizeCommand(t *testing.T) { func TestConfig_SanitizeCommand(t *testing.T) {
// Test a simple command // Test a simple command
args, err := SanitizeCommand("python model1.py") args, err := SanitizeCommand("python model1.py")
assert.NoError(t, err) assert.NoError(t, err)

58
proxy/helpers_test.go Normal file
View File

@@ -0,0 +1,58 @@
package proxy
import (
"fmt"
"os"
"path/filepath"
"runtime"
"sync"
"testing"
"github.com/gin-gonic/gin"
)
// Test servers are each given a unique listen port, handed out sequentially
// starting at 12000. portMutex guards the counter so concurrently running
// tests never bind the same address.
var (
	nextTestPort int = 12000
	portMutex    sync.Mutex
)
// TestMain gates the whole package's tests on the simple-responder helper
// binary being present, and switches gin into test mode to quiet its output.
func TestMain(m *testing.M) {
	binPath := getSimpleResponderPath()
	_, statErr := os.Stat(binPath)
	if os.IsNotExist(statErr) {
		fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binPath)
		os.Exit(1)
	}

	gin.SetMode(gin.TestMode)
	m.Run()
}
// Helper function to get the binary path
func getSimpleResponderPath() string {
goos := runtime.GOOS
goarch := runtime.GOARCH
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
// getTestSimpleResponderConfig builds a ModelConfig for a simple-responder
// that echoes expectedMessage, allocating the next unique test port so
// concurrent tests do not collide.
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
	portMutex.Lock()
	defer portMutex.Unlock()

	assignedPort := nextTestPort
	nextTestPort++
	return getTestSimpleResponderConfigPort(expectedMessage, assignedPort)
}
// getTestSimpleResponderConfigPort builds a ModelConfig that launches
// simple-responder on the given port, replying with expectedMessage, and
// points the proxy at the matching local address.
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
	bin := getSimpleResponderPath()

	return ModelConfig{
		Cmd:           fmt.Sprintf("%s --port %d --respond '%s'", bin, port, expectedMessage),
		Proxy:         fmt.Sprintf("http://127.0.0.1:%d", port),
		CheckEndpoint: "/health",
	}
}

View File

@@ -113,7 +113,7 @@ func (p *Process) Stop() {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if !p.isRunning { if !p.isRunning || p.cmd == nil || p.cmd.Process == nil {
return return
} }

View File

@@ -1,54 +1,15 @@
package proxy package proxy
import ( import (
"fmt"
"io" "io"
"math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath"
"runtime"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
// Check if the binary exists
func TestMain(m *testing.M) {
binaryPath := getBinaryPath()
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
os.Exit(1)
}
m.Run()
}
// Helper function to get the binary path
func getBinaryPath() string {
goos := runtime.GOOS
goarch := runtime.GOARCH
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
}
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
// Define the range
min := 12000
max := 13000
// Generate a random number between 12000 and 13000
randomPort := rand.Intn(max-min+1) + min
binaryPath := getBinaryPath()
// Create a process configuration
return ModelConfig{
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort),
CheckEndpoint: "/health",
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard) logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931" expectedMessage := "testing91931"
@@ -56,7 +17,9 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
// Create a process // Create a process
process := NewProcess("test-process", 5, config, logMonitor) process := NewProcess("test-process", 5, config, logMonitor)
req := httptest.NewRequest("GET", "/", nil) defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
// process is automatically started // process is automatically started
@@ -92,6 +55,8 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
} }
process := NewProcess("broken", 1, config, NewLogMonitor()) process := NewProcess("broken", 1, config, NewLogMonitor())
defer process.Stop()
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
@@ -99,6 +64,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
assert.Contains(t, w.Body.String(), "unable to start process") assert.Contains(t, w.Body.String(), "unable to start process")
} }
// test that the process unloads after the TTL
func TestProcess_UnloadAfterTTL(t *testing.T) { func TestProcess_UnloadAfterTTL(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping long auto unload TTL test") t.Skip("skipping long auto unload TTL test")
@@ -111,7 +77,9 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard)) process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
req := httptest.NewRequest("GET", "/", nil) defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Proxy the request (auto start) // Proxy the request (auto start)

View File

@@ -7,6 +7,7 @@ import (
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@@ -16,18 +17,18 @@ import (
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
config *Config config *Config
currentProcess *Process currentProcesses map[string]*Process
logMonitor *LogMonitor logMonitor *LogMonitor
ginEngine *gin.Engine ginEngine *gin.Engine
} }
func New(config *Config) *ProxyManager { func New(config *Config) *ProxyManager {
pm := &ProxyManager{ pm := &ProxyManager{
config: config, config: config,
currentProcess: nil, currentProcesses: make(map[string]*Process),
logMonitor: NewLogMonitor(), logMonitor: NewLogMonitor(),
ginEngine: gin.New(), ginEngine: gin.New(),
} }
// Set up routes using the Gin engine // Set up routes using the Gin engine
@@ -43,7 +44,7 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.NoRoute(pm.proxyRequestHandler) pm.ginEngine.NoRoute(pm.proxyNoRouteHandler)
// Disable console color for testing // Disable console color for testing
gin.DisableConsoleColor() gin.DisableConsoleColor()
@@ -59,6 +60,21 @@ func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r) pm.ginEngine.ServeHTTP(w, r)
} }
func (pm *ProxyManager) StopProcesses() {
pm.Lock()
defer pm.Unlock()
pm.stopProcesses()
}
// for internal usage
func (pm *ProxyManager) stopProcesses() {
for _, process := range pm.currentProcesses {
process.Stop()
}
pm.currentProcesses = make(map[string]*Process)
}
func (pm *ProxyManager) listModelsHandler(c *gin.Context) { func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
data := []interface{}{} data := []interface{}{}
for id := range pm.config.Models { for id := range pm.config.Models {
@@ -80,27 +96,64 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
} }
} }
func (pm *ProxyManager) swapModel(requestedModel string) error { func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
// find the model configuration matching requestedModel // Check if requestedModel contains a /
modelConfig, modelID, found := pm.config.FindConfig(requestedModel) groupName, modelName := "", requestedModel
if !found { if idx := strings.Index(requestedModel, "/"); idx != -1 {
return fmt.Errorf("could not find configuration for %s", requestedModel) groupName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
} }
// do nothing as it's already the correct process if groupName != "" {
if pm.currentProcess != nil { if _, found := pm.config.Profiles[groupName]; !found {
if pm.currentProcess.ID == modelID { return nil, fmt.Errorf("model group not found %s", groupName)
return nil
} else {
pm.currentProcess.Stop()
} }
} }
pm.currentProcess = NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) // de-alias the real model name and get a real one
return nil realModelName, found := pm.config.RealModelName(modelName)
if !found {
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
}
// exit early when already running, otherwise stop everything and swap
requestedProcessKey := groupName + "/" + realModelName
if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil
}
// stop all running models
pm.stopProcesses()
if groupName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := groupName + "/" + modelID
pm.currentProcesses[processKey] = process
} else {
for _, modelName := range pm.config.Profiles[groupName] {
if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName)
}
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := groupName + "/" + modelID
pm.currentProcesses[processKey] = process
}
}
}
// requestedProcessKey should exist due to swap
return pm.currentProcesses[requestedProcessKey], nil
} }
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) { func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
@@ -120,24 +173,27 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
return return
} }
if err := pm.swapModel(model); err != nil { if process, err := pm.swapModel(model); err != nil {
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error())) c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
return return
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
}
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
// since maps are unordered, just use the first available process if one exists
for _, process := range pm.currentProcesses {
process.ProxyRequest(c.Writer, c.Request)
return
} }
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
pm.currentProcess.ProxyRequest(c.Writer, c.Request)
}
func (pm *ProxyManager) proxyRequestHandler(c *gin.Context) {
if pm.currentProcess != nil {
pm.currentProcess.ProxyRequest(c.Writer, c.Request)
} else {
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
}
} }

View File

@@ -0,0 +1,76 @@
package proxy
import (
"bytes"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
)
// TestProxyManager_SwapProcessCorrectly requests two different models in
// sequence (no profile prefix) and verifies the manager swaps the single
// running process each time, leaving exactly one model loaded.
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
	config := &Config{
		HealthCheckTimeout: 15,
		Models: map[string]ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
	}

	manager := New(config)
	defer manager.StopProcesses()

	for _, modelName := range []string{"model1", "model2"} {
		payload := fmt.Sprintf(`{"model":"%s"}`, modelName)
		request := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(payload))
		recorder := httptest.NewRecorder()

		manager.HandlerFunc(recorder, request)
		assert.Equal(t, http.StatusOK, recorder.Code)
		assert.Contains(t, recorder.Body.String(), modelName)

		// ungrouped models are keyed with an empty profile prefix
		_, exists := manager.currentProcesses["/"+modelName]
		assert.True(t, exists, "expected %s key in currentProcesses", modelName)
	}

	// only the most recently requested model should remain loaded
	assert.Len(t, manager.currentProcesses, 1)
}
// TestProxyManager_SwapMultiProcess requests both members of the "test"
// profile and verifies they stay loaded at the same time, each under its
// profile-prefixed process key.
func TestProxyManager_SwapMultiProcess(t *testing.T) {
	config := &Config{
		HealthCheckTimeout: 15,
		Models: map[string]ModelConfig{
			"model1": getTestSimpleResponderConfig("model1"),
			"model2": getTestSimpleResponderConfig("model2"),
		},
		Profiles: map[string][]string{
			"test": {"model1", "model2"},
		},
	}

	manager := New(config)
	defer manager.StopProcesses()

	expected := map[string]string{
		"model1": "test/model1",
		"model2": "test/model2",
	}
	for modelID, requestedModel := range expected {
		payload := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
		request := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(payload))
		recorder := httptest.NewRecorder()

		manager.HandlerFunc(recorder, request)
		assert.Equal(t, http.StatusOK, recorder.Code)
		assert.Contains(t, recorder.Body.String(), modelID)
	}

	// both profile members should be loaded simultaneously
	assert.Len(t, manager.currentProcesses, 2)

	_, exists := manager.currentProcesses["test/model1"]
	assert.True(t, exists, "expected test/model1 key in currentProcesses")
	_, exists = manager.currentProcesses["test/model2"]
	assert.True(t, exists, "expected test/model2 key in currentProcesses")
}