From 09bdd86b548e3829a0a18baa81a8ece6aa980428 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Wed, 5 Feb 2025 17:19:59 -0800 Subject: [PATCH] Improve shutdown behaviour (#47) (#49) Introduce `Process.Shutdown()` and `ProxyManager.Shutdown()`. These two function required a lot of internal process state management refactoring. A key benefit is that `Process.start()` is now interruptable. When `Shutdown()` is called it will break the long health check loop. State management within Process is also improved. Added `starting`, `stopping` and `shutdown` states. Additionally, introduced a simple finite state machine to manage transitions. --- llama-swap.go | 2 +- proxy/process.go | 300 +++++++++++++++++++++++++------------ proxy/process_test.go | 117 ++++++++++++++- proxy/proxymanager.go | 27 +++- proxy/proxymanager_test.go | 50 +++++++ 5 files changed, 398 insertions(+), 98 deletions(-) diff --git a/llama-swap.go b/llama-swap.go index f7e6b7b..7df4336 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -47,7 +47,7 @@ func main() { go func() { <-sigChan fmt.Println("Shutting down llama-swap") - proxyManager.StopProcesses() + proxyManager.Shutdown() os.Exit(0) }() diff --git a/proxy/process.go b/proxy/process.go index 161e09b..65354da 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -2,6 +2,7 @@ package proxy import ( "context" + "errors" "fmt" "io" "net/http" @@ -16,14 +17,19 @@ import ( type ProcessState string const ( - StateStopped ProcessState = ProcessState("stopped") - StateReady ProcessState = ProcessState("ready") - StateFailed ProcessState = ProcessState("failed") + StateStopped ProcessState = ProcessState("stopped") + StateStarting ProcessState = ProcessState("starting") + StateReady ProcessState = ProcessState("ready") + StateStopping ProcessState = ProcessState("stopping") + + // failed a health check on start and will not be recovered + StateFailed ProcessState = ProcessState("failed") + + // process is shutdown and will not be restarted + StateShutdown ProcessState = ProcessState("shutdown") ) type Process struct { - sync.Mutex - ID string config ModelConfig cmd *exec.Cmd @@ -36,9 +42,17 @@ type Process struct { state ProcessState inFlightRequests sync.WaitGroup + + // used to block on multiple start() calls + waitStarting sync.WaitGroup + + // for managing shutdown state + shutdownCtx context.Context + shutdownCancel context.CancelFunc } func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process { + ctx, cancel := context.WithCancel(context.Background()) return &Process{ ID: ID, config: config, @@ -46,22 +60,88 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito logMonitor: logMonitor, healthCheckTimeout: healthCheckTimeout, state: StateStopped, + shutdownCtx: ctx, + shutdownCancel: cancel, } } -// start the process and returns when it is ready +func (p *Process) setState(newState ProcessState) error { + // enforce valid state transitions + invalidTransition := false + if p.state == StateStopped { + // stopped -> starting + if newState != StateStarting { + invalidTransition = true + } + } else if p.state == StateStarting { + // starting -> ready | failed | stopping + if newState != StateReady && newState != StateFailed && newState != StateStopping { + invalidTransition = true + } + } else if p.state == StateReady { + // ready -> stopping + if newState != StateStopping { + invalidTransition = true + } + } else if p.state == StateStopping { + // stopping -> stopped | shutdown + if newState != StateStopped && newState != StateShutdown { + invalidTransition = true + } + } else if p.state == StateFailed || p.state == StateShutdown { + invalidTransition = true + } + + if invalidTransition { + //panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState)) + return fmt.Errorf("invalid state transition from %s to %s", p.state, newState) + } + + p.state = newState + return nil +} + +func (p *Process) CurrentState() ProcessState { + p.stateMutex.RLock() + defer p.stateMutex.RUnlock() + return p.state +} + +// start starts the upstream command, checks the health endpoint, and sets the state to Ready +// it is a private method because starting is automatic but stopping can be called +// at any time. func (p *Process) start() error { + if p.config.Proxy == "" { + return fmt.Errorf("can not start(), upstream proxy missing") + } + + // wait for the other start() to complete + curState := p.CurrentState() + + if curState == StateReady { + return nil + } + + if curState == StateStarting { + p.waitStarting.Wait() + + if state := p.CurrentState(); state != StateReady { + return fmt.Errorf("start() failed current state: %v", state) + } + + return nil + } + p.stateMutex.Lock() defer p.stateMutex.Unlock() - if p.state == StateReady { - return nil + if err := p.setState(StateStarting); err != nil { + return err } - if p.state == StateFailed { - return fmt.Errorf("process is in a failed state and can not be restarted") - } + p.waitStarting.Add(1) + defer p.waitStarting.Done() args, err := p.config.SanitizedCommand() if err != nil { @@ -76,7 +156,8 @@ func (p *Process) start() error { err = p.cmd.Start() if err != nil { - return err + p.setState(StateFailed) + return fmt.Errorf("start() failed: %v", err) } // One of three things can happen at this stage: @@ -87,9 +168,55 @@ func (p *Process) start() error { // only in the third case will the process be considered Ready to accept <-time.After(250 * time.Millisecond) // give process a bit of time to start - if err := p.checkHealthEndpoint(); err != nil { - p.state = StateFailed - return err + checkStartTime := time.Now() + maxDuration := time.Second * time.Duration(p.healthCheckTimeout) + checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint) + + // a "none" means don't check for health ... I could have picked a better word :facepalm: + if checkEndpoint != "none" { + // keep default behaviour + if checkEndpoint == "" { + checkEndpoint = "/health" + } + + proxyTo := p.config.Proxy + healthURL, err := url.JoinPath(proxyTo, checkEndpoint) + if err != nil { + return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint) + } + + checkDeadline, cancelHealthCheck := context.WithDeadline( + context.Background(), + checkStartTime.Add(maxDuration), + ) + defer cancelHealthCheck() + + // Health check loop + loop: + for { + select { + case <-checkDeadline.Done(): + p.setState(StateFailed) + return fmt.Errorf("health check failed after %vs", maxDuration.Seconds()) + case <-p.shutdownCtx.Done(): + return errors.New("health check interrupted due to shutdown") + default: + if err := p.checkHealthEndpoint(healthURL); err == nil { + cancelHealthCheck() + break loop + } else { + if strings.Contains(err.Error(), "connection refused") { + endTime, _ := checkDeadline.Deadline() + ttl := time.Until(endTime) + fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds()) + } else { + fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err) + } + } + } + + <-time.After(time.Second) + } } if p.config.UnloadAfter > 0 { @@ -115,37 +242,63 @@ func (p *Process) start() error { }() } - p.state = StateReady - return nil + return p.setState(StateReady) } func (p *Process) Stop() { // wait for any inflight requests before proceeding p.inFlightRequests.Wait() - p.stateMutex.Lock() defer p.stateMutex.Unlock() - if p.state != StateReady { - fmt.Fprintf(p.logMonitor, "!!! Info - Stop() called but Process State is not READY\n") + // calling Stop() when state is invalid is a no-op + if err := p.setState(StateStopping); err != nil { + fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err) return } - if p.cmd == nil || p.cmd.Process == nil { - // this situation should never happen... but if it does just update the state - fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.\n") - p.state = StateStopped - return - } + // stop the process with a graceful exit timeout + p.stopCommand(5 * time.Second) - sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + if err := p.setState(StateStopped); err != nil { + panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err)) + } +} + +// Shutdown is called when llama-swap is shutting down. It will give a little bit +// of time for any inflight requests to complete before shutting down. If the Process +// is in the state of starting, it will cancel it and shut it down +func (p *Process) Shutdown() { + // cancel anything that can be interrupted by a shutdown (ie: healthcheck) + p.shutdownCancel() + + p.stateMutex.Lock() + defer p.stateMutex.Unlock() + p.setState(StateStopping) + + // 5 seconds to stop the process + p.stopCommand(5 * time.Second) + if err := p.setState(StateShutdown); err != nil { + fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err) + } + p.setState(StateShutdown) +} + +// stopCommand will send a SIGTERM to the process and wait for it to exit. +// If it does not exit within 5 seconds, it will send a SIGKILL. +func (p *Process) stopCommand(sigtermTTL time.Duration) { + sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) + defer cancelTimeout() sigtermNormal := make(chan error, 1) go func() { sigtermNormal <- p.cmd.Wait() }() + if p.cmd == nil || p.cmd.Process == nil { + panic("this should not happen, cmd or cmd.Process is nil") + } + p.cmd.Process.Signal(syscall.SIGTERM) select { @@ -170,94 +323,53 @@ func (p *Process) Stop() { } } } - - p.state = StateStopped } -func (p *Process) CurrentState() ProcessState { - p.stateMutex.RLock() - defer p.stateMutex.RUnlock() - return p.state -} +func (p *Process) checkHealthEndpoint(healthURL string) error { -func (p *Process) checkHealthEndpoint() error { - if p.config.Proxy == "" { - return fmt.Errorf("no upstream available to check /health") + client := &http.Client{ + Timeout: 500 * time.Millisecond, } - checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint) - - if checkEndpoint == "none" { - return nil - } - - // keep default behaviour - if checkEndpoint == "" { - checkEndpoint = "/health" - } - - proxyTo := p.config.Proxy - maxDuration := time.Second * time.Duration(p.healthCheckTimeout) - healthURL, err := url.JoinPath(proxyTo, checkEndpoint) + req, err := http.NewRequest("GET", healthURL, nil) if err != nil { - return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint) + return err } - client := &http.Client{} - startTime := time.Now() - - for { - req, err := http.NewRequest("GET", healthURL, nil) - if err != nil { - return err - } - - resp, err := client.Do(req) - - ttl := (maxDuration - time.Since(startTime)).Seconds() - - if err != nil { - // wait a bit longer for TCP connection issues - if strings.Contains(err.Error(), "connection refused") { - fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl) - time.Sleep(5 * time.Second) - } else { - time.Sleep(time.Second) - } - - if ttl < 0 { - return fmt.Errorf("failed to check health from: %s", healthURL) - } - - continue - } - - defer resp.Body.Close() - if resp.StatusCode == http.StatusOK { - return nil - } - - if ttl < 0 { - return fmt.Errorf("failed to check health from: %s", healthURL) - } - - time.Sleep(time.Second) + resp, err := client.Do(req) + if err != nil { + return err } + defer resp.Body.Close() + + // got a response but it was not an OK + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("status code: %d", resp.StatusCode) + } + + return nil } func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { - p.inFlightRequests.Add(1) + // prevent new requests from being made while stopping or irrecoverable + currentState := p.CurrentState() + if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping { + http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable) + return + } + p.inFlightRequests.Add(1) defer func() { p.lastRequestHandled = time.Now() p.inFlightRequests.Done() }() + // start the process on demand if p.CurrentState() != StateReady { if err := p.start(); err != nil { errstr := fmt.Sprintf("unable to start process: %s", err) - http.Error(w, errstr, http.StatusInternalServerError) + http.Error(w, errstr, http.StatusBadGateway) return } } diff --git a/proxy/process_test.go b/proxy/process_test.go index 0fdfdec..0d7d3f8 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -48,6 +48,33 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { } } +// TestProcess_WaitOnMultipleStarts tests that multiple concurrent requests +// are all handled successfully, even though they all may ask for the process to .start() +func TestProcess_WaitOnMultipleStarts(t *testing.T) { + + logMonitor := NewLogMonitorWriter(io.Discard) + expectedMessage := "testing91931" + config := getTestSimpleResponderConfig(expectedMessage) + + process := NewProcess("test-process", 5, config, logMonitor) + defer process.Stop() + + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func(reqID int) { + defer wg.Done() + req := httptest.NewRequest("GET", "/test", nil) + w := httptest.NewRecorder() + process.ProxyRequest(w, req) + assert.Equal(t, http.StatusOK, w.Code, "Worker %d got wrong HTTP code", reqID) + assert.Contains(t, w.Body.String(), expectedMessage, "Worker %d got wrong message", reqID) + }(i) + } + wg.Wait() + assert.Equal(t, StateReady, process.CurrentState()) +} + // test that the automatic start returns the expected error type func TestProcess_BrokenModelConfig(t *testing.T) { // Create a process configuration @@ -58,13 +85,17 @@ func TestProcess_BrokenModelConfig(t *testing.T) { } process := NewProcess("broken", 1, config, NewLogMonitor()) - defer process.Stop() req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() process.ProxyRequest(w, req) - assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, http.StatusBadGateway, w.Code) assert.Contains(t, w.Body.String(), "unable to start process") + + w = httptest.NewRecorder() + process.ProxyRequest(w, req) + assert.Equal(t, http.StatusServiceUnavailable, w.Code) + assert.Contains(t, w.Body.String(), "Process can not ProxyRequest, state is failed") } func TestProcess_UnloadAfterTTL(t *testing.T) { @@ -190,3 +221,85 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { assert.Equal(t, key, result) } } + +func TestSetState(t *testing.T) { + tests := []struct { + name string + currentState ProcessState + newState ProcessState + expectedError error + expectedResult ProcessState + }{ + {"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting}, + {"Starting to Ready", StateStarting, StateReady, nil, StateReady}, + {"Starting to Failed", StateStarting, StateFailed, nil, StateFailed}, + {"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping}, + {"Ready to Stopping", StateReady, StateStopping, nil, StateStopping}, + {"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped}, + {"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown}, + {"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped}, + {"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting}, + {"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady}, + {"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady}, + {"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping}, + {"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed}, + {"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed}, + {"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown}, + {"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + p := &Process{ + state: test.currentState, + } + + err := p.setState(test.newState) + if err != nil && test.expectedError == nil { + t.Errorf("Unexpected error: %v", err) + } else if err == nil && test.expectedError != nil { + t.Errorf("Expected error: %v, but got none", test.expectedError) + } else if err != nil && test.expectedError != nil { + if err.Error() != test.expectedError.Error() { + t.Errorf("Expected error: %v, got: %v", test.expectedError, err) + } + } + + if p.state != test.expectedResult { + t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state) + } + }) + } +} + +func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { + if testing.Short() { + t.Skip("skipping long shutdown test") + } + + logMonitor := NewLogMonitorWriter(io.Discard) + expectedMessage := "testing91931" + + // make a config where the healthcheck will always fail because port is wrong + config := getTestSimpleResponderConfigPort(expectedMessage, 9999) + config.Proxy = "http://localhost:9998/test" + + healthCheckTTLSeconds := 30 + process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor) + + // start a goroutine to simulate a shutdown + var wg sync.WaitGroup + go func() { + defer wg.Done() + <-time.After(time.Second * 2) + process.Shutdown() + }() + wg.Add(1) + + // start the process, this is a blocking call + err := process.start() + + wg.Wait() + assert.ErrorContains(t, err, "health check interrupted due to shutdown") + assert.Equal(t, StateShutdown, process.CurrentState()) +} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index edad17a..d81c716 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -156,13 +156,38 @@ func (pm *ProxyManager) stopProcesses() { return } + // stop Processes in parallel + var wg sync.WaitGroup for _, process := range pm.currentProcesses { - process.Stop() + wg.Add(1) + go func(process *Process) { + defer wg.Done() + process.Stop() + }(process) } + wg.Wait() pm.currentProcesses = make(map[string]*Process) } +// Shutdown is called to shutdown all upstream processes +// when llama-swap is shutting down. +func (pm *ProxyManager) Shutdown() { + pm.Lock() + defer pm.Unlock() + + // shutdown process in parallel + var wg sync.WaitGroup + for _, process := range pm.currentProcesses { + wg.Add(1) + go func(process *Process) { + defer wg.Done() + process.Shutdown() + }(process) + } + wg.Wait() +} + func (pm *ProxyManager) listModelsHandler(c *gin.Context) { data := []interface{}{} for id, modelConfig := range pm.config.Models { diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index e1476ef..f15e5f7 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -254,3 +254,53 @@ func TestProxyManager_ProfileNonMember(t *testing.T) { assert.Equal(t, http.StatusNotFound, w.Code) } } + +func TestProxyManager_Shutdown(t *testing.T) { + // make broken model configurations + model1Config := getTestSimpleResponderConfigPort("model1", 9991) + model1Config.Proxy = "http://localhost:10001/" + + model2Config := getTestSimpleResponderConfigPort("model2", 9992) + model2Config.Proxy = "http://localhost:10002/" + + model3Config := getTestSimpleResponderConfigPort("model3", 9993) + model3Config.Proxy = "http://localhost:10003/" + + config := &Config{ + HealthCheckTimeout: 15, + Profiles: map[string][]string{ + "test": {"model1", "model2", "model3"}, + }, + Models: map[string]ModelConfig{ + "model1": model1Config, + "model2": model2Config, + "model3": model3Config, + }, + } + + proxy := New(config) + + // Start all the processes + var wg sync.WaitGroup + for _, modelName := range []string{"test:model1", "test:model2", "test:model3"} { + wg.Add(1) + go func(modelName string) { + defer wg.Done() + reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + // send a request to trigger the proxy to load + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusBadGateway, w.Code) + assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") + //fmt.Println(w.Code, w.Body.String()) + }(modelName) + } + + go func() { + <-time.After(time.Second) + proxy.Shutdown() + }() + wg.Wait() +}