Implement Multi-Process Handling (#7)
Refactor code to support starting of multiple back end llama.cpp servers. This functionality is exposed as `profiles` to create a simple configuration format. Changes: * refactor proxy tests to get ready for multi-process support * update proxy/ProxyManager to support multiple processes (#7) * Add support for Groups in configuration * improve handling of Model alias configs * implement multi-model swapping * improve code clarity for swapModel * improve docs, rename groups to profiles in config
This commit is contained in:
46
README.md
46
README.md
@@ -2,17 +2,22 @@
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
[llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models on demand. So let's swap the server on demand instead!
|
llama-swap is a golang server that automatically swaps the llama.cpp server on demand. Since [llama.cpp's server](https://github.com/ggerganov/llama.cpp/tree/master/examples/server) can't swap models, let's swap the server instead!
|
||||||
|
|
||||||
llama-swap is a proxy server that sits in front of llama-server. When a request for `/v1/chat/completions` comes in it will extract the `model` requested and change the underlying llama-server automatically.
|
Features:
|
||||||
|
|
||||||
- ✅ easy to deploy: single binary with no dependencies
|
- ✅ Easy to deploy: single binary with no dependencies
|
||||||
- ✅ full control over llama-server's startup settings
|
- ✅ Single yaml configuration file
|
||||||
- ✅ ❤️ for users who are rely on llama.cpp for LLM inference
|
- ✅ Automatically switching between models
|
||||||
|
- ✅ Full control over llama.cpp server settings per model
|
||||||
|
- ✅ OpenAI API support (`v1/completions` and `v1/chat/completions`)
|
||||||
|
- ✅ Multiple GPU support
|
||||||
|
- ✅ Run multiple models at once with `profiles`
|
||||||
|
- ✅ Remote log monitoring at `/log`
|
||||||
|
|
||||||
## config.yaml
|
## config.yaml
|
||||||
|
|
||||||
llama-swap's configuration purposefully simple.
|
llama-swap's configuration is purposefully simple.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
# Seconds to wait for llama.cpp to load and be ready to serve requests
|
||||||
@@ -24,25 +29,24 @@ models:
|
|||||||
"llama":
|
"llama":
|
||||||
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
cmd: llama-server --port 8999 -m Llama-3.2-1B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
# where to reach the server started by cmd
|
# where to reach the server started by cmd, make sure the ports match
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
# aliases model names to use this configuration for
|
# aliases names to use this model for
|
||||||
aliases:
|
aliases:
|
||||||
- "gpt-4o-mini"
|
- "gpt-4o-mini"
|
||||||
- "gpt-3.5-turbo"
|
- "gpt-3.5-turbo"
|
||||||
|
|
||||||
# wait for this path to return an HTTP 200 before serving requests
|
# check this path for an HTTP 200 OK before serving requests
|
||||||
# defaults to /health to match llama.cpp
|
# default: /health to match llama.cpp
|
||||||
#
|
# use "none" to skip endpoint checking, but may cause HTTP errors
|
||||||
# use "none" to skip endpoint checking. This may cause requests to fail
|
# until the model is ready
|
||||||
# until the server is ready
|
|
||||||
checkEndpoint: /custom-endpoint
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
# automatically unload the model after 10 seconds
|
# automatically unload the model after this many seconds
|
||||||
# ttl values must be a value greater than 0
|
# ttl values must be a value greater than 0
|
||||||
# default: 0 = never unload model
|
# default: 0 = never unload model
|
||||||
ttl: 5
|
ttl: 60
|
||||||
|
|
||||||
"qwen":
|
"qwen":
|
||||||
# environment variables to pass to the command
|
# environment variables to pass to the command
|
||||||
@@ -53,8 +57,18 @@ models:
|
|||||||
cmd: >
|
cmd: >
|
||||||
llama-server --port 8999
|
llama-server --port 8999
|
||||||
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
--model path/to/Qwen2.5-1.5B-Instruct-Q4_K_M.gguf
|
||||||
|
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|
||||||
|
# profiles make it easy to managing multi model (and gpu) configurations.
|
||||||
|
#
|
||||||
|
# Tips:
|
||||||
|
# - each model must be listening on a unique address and port
|
||||||
|
# - the model name is in this format: "profile_name/model", like "coding/qwen"
|
||||||
|
# - the profile will load and unload all models in the profile at the same time
|
||||||
|
profiles:
|
||||||
|
coding:
|
||||||
|
- "qwen"
|
||||||
|
- "llama"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|||||||
@@ -6,9 +6,9 @@ models:
|
|||||||
"llama":
|
"llama":
|
||||||
cmd: >
|
cmd: >
|
||||||
models/llama-server-osx
|
models/llama-server-osx
|
||||||
--port 8999
|
--port 9001
|
||||||
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
-m models/Llama-3.2-1B-Instruct-Q4_0.gguf
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:9001
|
||||||
|
|
||||||
# list of model name aliases this llama.cpp instance can serve
|
# list of model name aliases this llama.cpp instance can serve
|
||||||
aliases:
|
aliases:
|
||||||
@@ -21,8 +21,8 @@ models:
|
|||||||
ttl: 5
|
ttl: 5
|
||||||
|
|
||||||
"qwen":
|
"qwen":
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
cmd: models/llama-server-osx --port 9002 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:9002
|
||||||
aliases:
|
aliases:
|
||||||
- gpt-3.5-turbo
|
- gpt-3.5-turbo
|
||||||
|
|
||||||
@@ -44,4 +44,10 @@ models:
|
|||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
"broken_timeout":
|
"broken_timeout":
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||||
proxy: http://127.0.0.1:9000
|
proxy: http://127.0.0.1:9000
|
||||||
|
|
||||||
|
# creating a coding profile with models for code generation and general questions
|
||||||
|
profiles:
|
||||||
|
coding:
|
||||||
|
- "qwen"
|
||||||
|
- "llama"
|
||||||
@@ -16,12 +16,16 @@ func main() {
|
|||||||
|
|
||||||
flag.Parse() // Parse the command-line flags
|
flag.Parse() // Parse the command-line flags
|
||||||
|
|
||||||
// Set up the handler function using the provided response message
|
responseMessageHandler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// Set the header to text/plain
|
// Set the header to text/plain
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
fmt.Fprintln(w, *responseMessage)
|
fmt.Fprintln(w, *responseMessage)
|
||||||
})
|
}
|
||||||
|
|
||||||
|
// Set up the handler function using the provided response message
|
||||||
|
http.HandleFunc("/v1/chat/completions", responseMessageHandler)
|
||||||
|
http.HandleFunc("/v1/completions", responseMessageHandler)
|
||||||
|
http.HandleFunc("/test", responseMessageHandler)
|
||||||
|
|
||||||
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
@@ -43,6 +47,11 @@ func main() {
|
|||||||
w.Write([]byte(response))
|
w.Write([]byte(response))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
|
||||||
|
})
|
||||||
|
|
||||||
address := "127.0.0.1:" + *port // Address with the specified port
|
address := "127.0.0.1:" + *port // Address with the specified port
|
||||||
fmt.Printf("Server is listening on port %s\n", *port)
|
fmt.Printf("Server is listening on port %s\n", *port)
|
||||||
|
|
||||||
|
|||||||
@@ -22,26 +22,30 @@ func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Models map[string]ModelConfig `yaml:"models"`
|
|
||||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||||
|
Models map[string]ModelConfig `yaml:"models"`
|
||||||
|
Profiles map[string][]string `yaml:"profiles"`
|
||||||
|
|
||||||
|
// map aliases to actual model IDs
|
||||||
|
aliases map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) RealModelName(search string) (string, bool) {
|
||||||
|
if _, found := c.Models[search]; found {
|
||||||
|
return search, true
|
||||||
|
} else if name, found := c.aliases[search]; found {
|
||||||
|
return name, found
|
||||||
|
} else {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||||
modelConfig, found := c.Models[modelName]
|
if realName, found := c.RealModelName(modelName); !found {
|
||||||
if found {
|
return ModelConfig{}, "", false
|
||||||
return modelConfig, modelName, true
|
} else {
|
||||||
|
return c.Models[realName], realName, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Search through aliases to find the right config
|
|
||||||
for actual, config := range c.Models {
|
|
||||||
for _, alias := range config.Aliases {
|
|
||||||
if alias == modelName {
|
|
||||||
return config, actual, true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return ModelConfig{}, "", false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig(path string) (*Config, error) {
|
func LoadConfig(path string) (*Config, error) {
|
||||||
@@ -60,6 +64,14 @@ func LoadConfig(path string) (*Config, error) {
|
|||||||
config.HealthCheckTimeout = 15
|
config.HealthCheckTimeout = 15
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Populate the aliases map
|
||||||
|
config.aliases = make(map[string]string)
|
||||||
|
for modelName, modelConfig := range config.Models {
|
||||||
|
for _, alias := range modelConfig.Aliases {
|
||||||
|
config.aliases[alias] = modelName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLoadConfig(t *testing.T) {
|
func TestConfig_Load(t *testing.T) {
|
||||||
// Create a temporary YAML file for testing
|
// Create a temporary YAML file for testing
|
||||||
tempDir, err := os.MkdirTemp("", "test-config")
|
tempDir, err := os.MkdirTemp("", "test-config")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -17,7 +17,8 @@ func TestLoadConfig(t *testing.T) {
|
|||||||
defer os.RemoveAll(tempDir)
|
defer os.RemoveAll(tempDir)
|
||||||
|
|
||||||
tempFile := filepath.Join(tempDir, "config.yaml")
|
tempFile := filepath.Join(tempDir, "config.yaml")
|
||||||
content := `models:
|
content := `
|
||||||
|
models:
|
||||||
model1:
|
model1:
|
||||||
cmd: path/to/cmd --arg1 one
|
cmd: path/to/cmd --arg1 one
|
||||||
proxy: "http://localhost:8080"
|
proxy: "http://localhost:8080"
|
||||||
@@ -28,7 +29,17 @@ func TestLoadConfig(t *testing.T) {
|
|||||||
- "VAR1=value1"
|
- "VAR1=value1"
|
||||||
- "VAR2=value2"
|
- "VAR2=value2"
|
||||||
checkEndpoint: "/health"
|
checkEndpoint: "/health"
|
||||||
|
model2:
|
||||||
|
cmd: path/to/cmd --arg1 one
|
||||||
|
proxy: "http://localhost:8081"
|
||||||
|
aliases:
|
||||||
|
- "m2"
|
||||||
|
checkEndpoint: "/"
|
||||||
healthCheckTimeout: 15
|
healthCheckTimeout: 15
|
||||||
|
profiles:
|
||||||
|
test:
|
||||||
|
- model1
|
||||||
|
- model2
|
||||||
`
|
`
|
||||||
|
|
||||||
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
if err := os.WriteFile(tempFile, []byte(content), 0644); err != nil {
|
||||||
@@ -50,14 +61,33 @@ healthCheckTimeout: 15
|
|||||||
Env: []string{"VAR1=value1", "VAR2=value2"},
|
Env: []string{"VAR1=value1", "VAR2=value2"},
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
},
|
},
|
||||||
|
"model2": {
|
||||||
|
Cmd: "path/to/cmd --arg1 one",
|
||||||
|
Proxy: "http://localhost:8081",
|
||||||
|
Aliases: []string{"m2"},
|
||||||
|
Env: nil,
|
||||||
|
CheckEndpoint: "/",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, expected, config)
|
assert.Equal(t, expected, config)
|
||||||
|
|
||||||
|
realname, found := config.RealModelName("m1")
|
||||||
|
assert.True(t, found)
|
||||||
|
assert.Equal(t, "model1", realname)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestModelConfigSanitizedCommand(t *testing.T) {
|
func TestConfig_ModelConfigSanitizedCommand(t *testing.T) {
|
||||||
config := &ModelConfig{
|
config := &ModelConfig{
|
||||||
Cmd: `python model1.py \
|
Cmd: `python model1.py \
|
||||||
--arg1 value1 \
|
--arg1 value1 \
|
||||||
@@ -69,7 +99,10 @@ func TestModelConfigSanitizedCommand(t *testing.T) {
|
|||||||
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
assert.Equal(t, []string{"python", "model1.py", "--arg1", "value1", "--arg2", "value2"}, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestFindConfig(t *testing.T) {
|
func TestConfig_FindConfig(t *testing.T) {
|
||||||
|
|
||||||
|
// TODO?
|
||||||
|
// make make this shared between the different tests
|
||||||
config := &Config{
|
config := &Config{
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"model1": {
|
"model1": {
|
||||||
@@ -88,6 +121,11 @@ func TestFindConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
HealthCheckTimeout: 10,
|
HealthCheckTimeout: 10,
|
||||||
|
aliases: map[string]string{
|
||||||
|
"m1": "model1",
|
||||||
|
"model-one": "model1",
|
||||||
|
"m2": "model2",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test finding a model by its name
|
// Test finding a model by its name
|
||||||
@@ -109,7 +147,7 @@ func TestFindConfig(t *testing.T) {
|
|||||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSanitizeCommand(t *testing.T) {
|
func TestConfig_SanitizeCommand(t *testing.T) {
|
||||||
// Test a simple command
|
// Test a simple command
|
||||||
args, err := SanitizeCommand("python model1.py")
|
args, err := SanitizeCommand("python model1.py")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|||||||
58
proxy/helpers_test.go
Normal file
58
proxy/helpers_test.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
nextTestPort int = 12000
|
||||||
|
portMutex sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check if the binary exists
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
binaryPath := getSimpleResponderPath()
|
||||||
|
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
||||||
|
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
m.Run()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to get the binary path
|
||||||
|
func getSimpleResponderPath() string {
|
||||||
|
goos := runtime.GOOS
|
||||||
|
goarch := runtime.GOARCH
|
||||||
|
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||||
|
portMutex.Lock()
|
||||||
|
defer portMutex.Unlock()
|
||||||
|
|
||||||
|
port := nextTestPort
|
||||||
|
nextTestPort++
|
||||||
|
|
||||||
|
return getTestSimpleResponderConfigPort(expectedMessage, port)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelConfig {
|
||||||
|
binaryPath := getSimpleResponderPath()
|
||||||
|
|
||||||
|
// Create a process configuration
|
||||||
|
return ModelConfig{
|
||||||
|
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, port, expectedMessage),
|
||||||
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -113,7 +113,7 @@ func (p *Process) Stop() {
|
|||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
if !p.isRunning {
|
if !p.isRunning || p.cmd == nil || p.cmd.Process == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,54 +1,15 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check if the binary exists
|
|
||||||
func TestMain(m *testing.M) {
|
|
||||||
binaryPath := getBinaryPath()
|
|
||||||
if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
|
|
||||||
fmt.Printf("simple-responder not found at %s, did you `make simple-responder`?\n", binaryPath)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
m.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to get the binary path
|
|
||||||
func getBinaryPath() string {
|
|
||||||
goos := runtime.GOOS
|
|
||||||
goarch := runtime.GOARCH
|
|
||||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
|
||||||
// Define the range
|
|
||||||
min := 12000
|
|
||||||
max := 13000
|
|
||||||
|
|
||||||
// Generate a random number between 12000 and 13000
|
|
||||||
randomPort := rand.Intn(max-min+1) + min
|
|
||||||
binaryPath := getBinaryPath()
|
|
||||||
|
|
||||||
// Create a process configuration
|
|
||||||
return ModelConfig{
|
|
||||||
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage),
|
|
||||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort),
|
|
||||||
CheckEndpoint: "/health",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||||
logMonitor := NewLogMonitorWriter(io.Discard)
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
expectedMessage := "testing91931"
|
expectedMessage := "testing91931"
|
||||||
@@ -56,7 +17,9 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
|||||||
|
|
||||||
// Create a process
|
// Create a process
|
||||||
process := NewProcess("test-process", 5, config, logMonitor)
|
process := NewProcess("test-process", 5, config, logMonitor)
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
defer process.Stop()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// process is automatically started
|
// process is automatically started
|
||||||
@@ -92,6 +55,8 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
process := NewProcess("broken", 1, config, NewLogMonitor())
|
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
@@ -99,6 +64,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
|
|||||||
assert.Contains(t, w.Body.String(), "unable to start process")
|
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test that the process unloads after the TTL
|
||||||
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping long auto unload TTL test")
|
t.Skip("skipping long auto unload TTL test")
|
||||||
@@ -111,7 +77,9 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
|
|||||||
assert.Equal(t, 3, config.UnloadAfter)
|
assert.Equal(t, 3, config.UnloadAfter)
|
||||||
|
|
||||||
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
defer process.Stop()
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Proxy the request (auto start)
|
// Proxy the request (auto start)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -16,18 +17,18 @@ import (
|
|||||||
type ProxyManager struct {
|
type ProxyManager struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
config *Config
|
config *Config
|
||||||
currentProcess *Process
|
currentProcesses map[string]*Process
|
||||||
logMonitor *LogMonitor
|
logMonitor *LogMonitor
|
||||||
ginEngine *gin.Engine
|
ginEngine *gin.Engine
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(config *Config) *ProxyManager {
|
func New(config *Config) *ProxyManager {
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
currentProcess: nil,
|
currentProcesses: make(map[string]*Process),
|
||||||
logMonitor: NewLogMonitor(),
|
logMonitor: NewLogMonitor(),
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
@@ -43,7 +44,7 @@ func New(config *Config) *ProxyManager {
|
|||||||
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
||||||
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
|
||||||
|
|
||||||
pm.ginEngine.NoRoute(pm.proxyRequestHandler)
|
pm.ginEngine.NoRoute(pm.proxyNoRouteHandler)
|
||||||
|
|
||||||
// Disable console color for testing
|
// Disable console color for testing
|
||||||
gin.DisableConsoleColor()
|
gin.DisableConsoleColor()
|
||||||
@@ -59,6 +60,21 @@ func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
|
|||||||
pm.ginEngine.ServeHTTP(w, r)
|
pm.ginEngine.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) StopProcesses() {
|
||||||
|
pm.Lock()
|
||||||
|
defer pm.Unlock()
|
||||||
|
|
||||||
|
pm.stopProcesses()
|
||||||
|
}
|
||||||
|
|
||||||
|
// for internal usage
|
||||||
|
func (pm *ProxyManager) stopProcesses() {
|
||||||
|
for _, process := range pm.currentProcesses {
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
pm.currentProcesses = make(map[string]*Process)
|
||||||
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||||
data := []interface{}{}
|
data := []interface{}{}
|
||||||
for id := range pm.config.Models {
|
for id := range pm.config.Models {
|
||||||
@@ -80,27 +96,64 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) swapModel(requestedModel string) error {
|
func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
||||||
pm.Lock()
|
pm.Lock()
|
||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// find the model configuration matching requestedModel
|
// Check if requestedModel contains a /
|
||||||
modelConfig, modelID, found := pm.config.FindConfig(requestedModel)
|
groupName, modelName := "", requestedModel
|
||||||
if !found {
|
if idx := strings.Index(requestedModel, "/"); idx != -1 {
|
||||||
return fmt.Errorf("could not find configuration for %s", requestedModel)
|
groupName = requestedModel[:idx]
|
||||||
|
modelName = requestedModel[idx+1:]
|
||||||
}
|
}
|
||||||
|
|
||||||
// do nothing as it's already the correct process
|
if groupName != "" {
|
||||||
if pm.currentProcess != nil {
|
if _, found := pm.config.Profiles[groupName]; !found {
|
||||||
if pm.currentProcess.ID == modelID {
|
return nil, fmt.Errorf("model group not found %s", groupName)
|
||||||
return nil
|
|
||||||
} else {
|
|
||||||
pm.currentProcess.Stop()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.currentProcess = NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
// de-alias the real model name and get a real one
|
||||||
return nil
|
realModelName, found := pm.config.RealModelName(modelName)
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("could not find modelID for %s", requestedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// exit early when already running, otherwise stop everything and swap
|
||||||
|
requestedProcessKey := groupName + "/" + realModelName
|
||||||
|
if process, found := pm.currentProcesses[requestedProcessKey]; found {
|
||||||
|
return process, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop all running models
|
||||||
|
pm.stopProcesses()
|
||||||
|
|
||||||
|
if groupName == "" {
|
||||||
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("could not find configuration for %s", realModelName)
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
|
processKey := groupName + "/" + modelID
|
||||||
|
pm.currentProcesses[processKey] = process
|
||||||
|
} else {
|
||||||
|
for _, modelName := range pm.config.Profiles[groupName] {
|
||||||
|
if realModelName, found := pm.config.RealModelName(modelName); found {
|
||||||
|
modelConfig, modelID, found := pm.config.FindConfig(realModelName)
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName)
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
|
processKey := groupName + "/" + modelID
|
||||||
|
pm.currentProcesses[processKey] = process
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// requestedProcessKey should exist due to swap
|
||||||
|
return pm.currentProcesses[requestedProcessKey], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
||||||
@@ -120,24 +173,27 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pm.swapModel(model); err != nil {
|
if process, err := pm.swapModel(model); err != nil {
|
||||||
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
|
||||||
return
|
return
|
||||||
|
} else {
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
|
c.Request.Header.Del("transfer-encoding")
|
||||||
|
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
|
||||||
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
|
||||||
|
// since maps are unordered, just use the first available process if one exists
|
||||||
|
for _, process := range pm.currentProcesses {
|
||||||
|
process.ProxyRequest(c.Writer, c.Request)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
|
||||||
c.Request.Header.Del("transfer-encoding")
|
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
|
||||||
|
|
||||||
pm.currentProcess.ProxyRequest(c.Writer, c.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyRequestHandler(c *gin.Context) {
|
|
||||||
if pm.currentProcess != nil {
|
|
||||||
pm.currentProcess.ProxyRequest(c.Writer, c.Request)
|
|
||||||
} else {
|
|
||||||
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
76
proxy/proxymanager_test.go
Normal file
76
proxy/proxymanager_test.go
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelName)
|
||||||
|
|
||||||
|
_, exists := proxy.currentProcesses["/"+modelName]
|
||||||
|
assert.True(t, exists, "expected %s key in currentProcesses", modelName)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there's only one loaded model
|
||||||
|
assert.Len(t, proxy.currentProcesses, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"model1", "model2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure there's two loaded models
|
||||||
|
assert.Len(t, proxy.currentProcesses, 2)
|
||||||
|
_, exists := proxy.currentProcesses["test/model1"]
|
||||||
|
assert.True(t, exists, "expected test/model1 key in currentProcesses")
|
||||||
|
|
||||||
|
_, exists = proxy.currentProcesses["test/model2"]
|
||||||
|
assert.True(t, exists, "expected test/model2 key in currentProcesses")
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user