From 533162ce6af65e5cd67fba573a311d0855f0d858 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Tue, 19 Nov 2024 16:32:51 -0800 Subject: [PATCH] add support for automatically unloading a model (#10) (#14) * Make starting upstream process on-demand (#10) * Add automatic unload of model after TTL is reached * add `ttl` configuration parameter to models in seconds, default is 0 (never unload) --- Makefile | 3 + README.md | 5 ++ config.example.yaml | 3 + misc/simple-responder/simple-responder.go | 4 + proxy/config.go | 1 + proxy/process.go | 86 +++++++++++++++----- proxy/process_test.go | 97 +++++++++++++++-------- proxy/proxymanager.go | 4 +- 8 files changed, 149 insertions(+), 54 deletions(-) diff --git a/Makefile b/Makefile index c7d0d1e..47c8999 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,9 @@ clean: rm -rf $(BUILD_DIR) test: + go test -short -v ./proxy + +test-all: go test -v ./proxy # Build OSX binary diff --git a/README.md b/README.md index 8dfaa47..4b232d8 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,11 @@ models: # until the server is ready checkEndpoint: /custom-endpoint + # automatically unload the model after 5 seconds + # ttl is given in seconds and must be greater than 0 to enable unloading + # default: 0 = never unload model + ttl: 5 + "qwen": # environment variables to pass to the command env: diff --git a/config.example.yaml b/config.example.yaml index e28a941..e660b11 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -17,6 +17,9 @@ models: # check this path for a HTTP 200 response for the server to be ready checkEndpoint: /health + # unload model after 5 seconds + ttl: 5 + "qwen": cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf proxy: http://127.0.0.1:8999 diff --git a/misc/simple-responder/simple-responder.go b/misc/simple-responder/simple-responder.go index a7a1fea..61fe3f0 100644 --- a/misc/simple-responder/simple-responder.go +++ b/misc/simple-responder/simple-responder.go @@ -20,7 +20,11 @@ func main() { 
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { // Set the header to text/plain w.Header().Set("Content-Type", "text/plain") + fmt.Fprintln(w, *responseMessage) + }) + http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") fmt.Fprintln(w, *responseMessage) // Get environment variables diff --git a/proxy/config.go b/proxy/config.go index 37aac8f..ae81150 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -14,6 +14,7 @@ type ModelConfig struct { Aliases []string `yaml:"aliases"` Env []string `yaml:"env"` CheckEndpoint string `yaml:"checkEndpoint"` + UnloadAfter int `yaml:"ttl"` } func (m *ModelConfig) SanitizedCommand() ([]string, error) { diff --git a/proxy/process.go b/proxy/process.go index 8a8e465..f4d30c6 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -17,27 +17,33 @@ import ( type Process struct { sync.Mutex - ID string - config ModelConfig - cmd *exec.Cmd - logMonitor *LogMonitor + ID string + config ModelConfig + cmd *exec.Cmd + logMonitor *LogMonitor + healthCheckTimeout int + + isRunning bool + lastRequestHandled time.Time } -func NewProcess(ID string, config ModelConfig, logMonitor *LogMonitor) *Process { +func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process { return &Process{ - ID: ID, - config: config, - cmd: nil, - logMonitor: logMonitor, + ID: ID, + config: config, + cmd: nil, + logMonitor: logMonitor, + healthCheckTimeout: healthCheckTimeout, } } -func (p *Process) Start(healthCheckTimeout int) error { +// start the process and check it for errors +func (p *Process) start() error { p.Lock() defer p.Unlock() - if p.cmd != nil { - return fmt.Errorf("process already started") + if p.isRunning { + return fmt.Errorf("process already running") } args, err := p.config.SanitizedCommand() @@ -51,6 +57,8 @@ func (p *Process) Start(healthCheckTimeout int) error { p.cmd.Env = p.config.Env err = p.cmd.Start() + 
p.isRunning = true + if err != nil { return err } @@ -58,7 +66,8 @@ func (p *Process) Start(healthCheckTimeout int) error { // watch for the command to exit cmdCtx, cancel := context.WithCancelCause(context.Background()) - // monitor the command's exit status + // monitor the command's exit status. Usually this happens if + // the process exited unexpectedly go func() { err := p.cmd.Wait() if err != nil { @@ -66,13 +75,37 @@ func (p *Process) Start(healthCheckTimeout int) error { } else { cancel(nil) } + + p.isRunning = false }() + // wait a bit for process to start before checking the health endpoint + time.Sleep(250 * time.Millisecond) + // wait for checkHealthEndpoint - if err := p.checkHealthEndpoint(cmdCtx, healthCheckTimeout); err != nil { + if err := p.checkHealthEndpoint(cmdCtx); err != nil { return err } + if p.config.UnloadAfter > 0 { + // start a goroutine to check every second if + // the process should be stopped + go func() { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + maxDuration := time.Duration(p.config.UnloadAfter) * time.Second + + for { + <-ticker.C + if time.Since(p.lastRequestHandled) > maxDuration { + fmt.Fprintf(p.logMonitor, "!!! 
Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter) + p.Stop() + return + } + } + }() + } + return nil } @@ -80,15 +113,20 @@ func (p *Process) Stop() { p.Lock() defer p.Unlock() - if p.cmd == nil { + if !p.isRunning { return } p.cmd.Process.Signal(syscall.SIGTERM) p.cmd.Process.Wait() + p.isRunning = false } -func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout int) error { +func (p *Process) IsRunning() bool { + return p.isRunning +} + +func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error { if p.config.Proxy == "" { return fmt.Errorf("no upstream available to check /health") } @@ -105,7 +143,7 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout } proxyTo := p.config.Proxy - maxDuration := time.Second * time.Duration(healthCheckTimeout) + maxDuration := time.Second * time.Duration(p.healthCheckTimeout) healthURL, err := url.JoinPath(proxyTo, checkEndpoint) if err != nil { return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint) @@ -115,7 +153,6 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout startTime := time.Now() for { - time.Sleep(time.Second) req, err := http.NewRequest("GET", healthURL, nil) if err != nil { return err @@ -162,15 +199,22 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout if ttl < 0 { return fmt.Errorf("failed to check health from: %s", healthURL) } + + time.Sleep(time.Second) } } func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { - if p.cmd == nil { - http.Error(w, "process not started", http.StatusInternalServerError) - return + if !p.isRunning { + if err := p.start(); err != nil { + errstr := fmt.Sprintf("unable to start process: %s", err) + http.Error(w, errstr, http.StatusInternalServerError) + return + } } + p.lastRequestHandled = time.Now() + proxyTo := p.config.Proxy client := &http.Client{} req, err := 
http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) diff --git a/proxy/process_test.go b/proxy/process_test.go index eab3e5a..34221c8 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -2,14 +2,17 @@ package proxy import ( "fmt" + "io" "math/rand" "net/http" "net/http/httptest" "os" "path/filepath" "runtime" - "strings" "testing" + "time" + + "github.com/stretchr/testify/assert" ) // Check if the binary exists @@ -29,53 +32,40 @@ func getBinaryPath() string { return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) } -func TestProcess_ProcessStartStop(t *testing.T) { +func getTestSimpleResponderConfig(expectedMessage string) ModelConfig { // Define the range min := 12000 max := 13000 // Generate a random number between 12000 and 13000 randomPort := rand.Intn(max-min+1) + min - binaryPath := getBinaryPath() - // Create a log monitor - logMonitor := NewLogMonitor() - - expectedMessage := "testing91931" - // Create a process configuration - config := ModelConfig{ + return ModelConfig{ Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage), Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort), CheckEndpoint: "/health", } +} + +func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { + logMonitor := NewLogMonitorWriter(io.Discard) + expectedMessage := "testing91931" + config := getTestSimpleResponderConfig(expectedMessage) // Create a process - process := NewProcess("test-process", config, logMonitor) - - // Start the process - t.Logf("Starting %s on port %d", binaryPath, randomPort) - err := process.Start(5) - if err != nil { - t.Fatalf("Failed to start process: %v", err) - } - - // Create a test request + process := NewProcess("test-process", 5, config, logMonitor) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() - // Proxy the request + // process is automatically started + assert.False(t, process.IsRunning()) process.ProxyRequest(w, req) + 
assert.True(t, process.IsRunning()) - // Check the response - if w.Code != http.StatusOK { - t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) - } - - if !strings.Contains(w.Body.String(), expectedMessage) { - t.Errorf("Expected body to contain '%s', got %q", expectedMessage, w.Body.String()) - } + assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), expectedMessage) // Stop the process process.Stop() @@ -86,8 +76,53 @@ func TestProcess_ProcessStartStop(t *testing.T) { // Proxy the request process.ProxyRequest(w, req) - // Check the response - if w.Code == http.StatusInternalServerError { - t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code) + // should have automatically started the process again + if w.Code != http.StatusOK { + t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code) } } + +// test that the automatic start returns the expected error type +func TestProcess_BrokenModelConfig(t *testing.T) { + // Create a process configuration + config := ModelConfig{ + Cmd: "nonexistant-command", + Proxy: "http://127.0.0.1:9913", + CheckEndpoint: "/health", + } + + process := NewProcess("broken", 1, config, NewLogMonitor()) + req := httptest.NewRequest("GET", "/", nil) + w := httptest.NewRecorder() + process.ProxyRequest(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Contains(t, w.Body.String(), "unable to start process") +} + +func TestProcess_UnloadAfterTTL(t *testing.T) { + if testing.Short() { + t.Skip("skipping long auto unload TTL test") + } + + expectedMessage := "I_sense_imminent_danger" + config := getTestSimpleResponderConfig(expectedMessage) + assert.Equal(t, 0, config.UnloadAfter) + config.UnloadAfter = 3 // seconds + assert.Equal(t, 3, config.UnloadAfter) + + process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard)) + req := httptest.NewRequest("GET", "/", nil) + w := 
httptest.NewRecorder() + + // Proxy the request (auto start) + process.ProxyRequest(w, req) + assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), expectedMessage) + + assert.True(t, process.IsRunning()) + + // wait 5 seconds + time.Sleep(5 * time.Second) + + assert.False(t, process.IsRunning()) +} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 7e07690..601fa6e 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -99,8 +99,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) error { } } - pm.currentProcess = NewProcess(modelID, modelConfig, pm.logMonitor) - return pm.currentProcess.Start(pm.config.HealthCheckTimeout) + pm.currentProcess = NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) + return nil } func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {