diff --git a/proxy/process.go b/proxy/process.go index 92fd7f4..2169cf5 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -34,6 +34,9 @@ type Process struct { config ModelConfig cmd *exec.Cmd + // for p.cmd.Wait() select { ... } + cmdWaitChan chan error + processLogger *LogMonitor proxyLogger *LogMonitor @@ -61,6 +64,7 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo ID: ID, config: config, cmd: nil, + cmdWaitChan: make(chan error, 1), processLogger: processLogger, proxyLogger: proxyLogger, healthCheckTimeout: healthCheckTimeout, @@ -89,16 +93,17 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, defer p.stateMutex.Unlock() if p.state != expectedState { + p.proxyLogger.Warnf("swapState() Unexpected current state %s, expected %s", p.state, expectedState) return p.state, ErrExpectedStateMismatch } if !isValidTransition(p.state, newState) { - p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState) + p.proxyLogger.Warnf("swapState() Invalid state transition from %s to %s", p.state, newState) return p.state, ErrInvalidStateTransition } - p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState) p.state = newState + p.proxyLogger.Debugf("swapState() State transitioned from %s to %s", expectedState, newState) return p.state, nil } @@ -179,6 +184,13 @@ func (p *Process) start() error { return fmt.Errorf("start() failed: %v", err) } + // Capture the exit error for later signaling + go func() { + exitErr := p.cmd.Wait() + p.proxyLogger.Debugf("cmd.Wait() returned for [%s] error: %v", p.ID, exitErr) + p.cmdWaitChan <- exitErr + }() + // One of three things can happen at this stage: // 1. The command exits unexpectedly // 2. The health check fails @@ -222,6 +234,22 @@ func (p *Process) start() error { } case <-p.shutdownCtx.Done(): return errors.New("health check interrupted due to shutdown") + case exitErr := <-p.cmdWaitChan: + if exitErr != nil { + p.proxyLogger.Warnf("upstream command exited prematurely with error: %v", exitErr) + if curState, err := p.swapState(StateStarting, StateFailed); err != nil { + return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState) + } else { + return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error()) + } + } else { + p.proxyLogger.Warnf("upstream command exited prematurely with no error") + if curState, err := p.swapState(StateStarting, StateFailed); err != nil { + return fmt.Errorf("upstream command exited prematurely with no error AND state swap failed: %v, current state: %v", err, curState) + } else { + return fmt.Errorf("upstream command exited prematurely with no error") + } + } default: if err := p.checkHealthEndpoint(healthURL); err == nil { p.proxyLogger.Infof("Health check passed on %s", healthURL) @@ -257,7 +285,6 @@ func (p *Process) start() error { p.inFlightRequests.Wait() if time.Since(p.lastRequestHandled) > maxDuration { - p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter) p.Stop() return @@ -276,6 +303,7 @@ func (p *Process) start() error { func (p *Process) Stop() { // wait for any inflight requests before proceeding p.inFlightRequests.Wait() + p.proxyLogger.Debugf("Stopping process [%s]", p.ID) // calling Stop() when state is invalid is a no-op if curState, err := p.swapState(StateReady, StateStopping); err != nil { @@ -311,11 +339,6 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) { sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL) defer cancelTimeout() - sigtermNormal := make(chan error, 1) - go func() { - sigtermNormal <- p.cmd.Wait() - }() - if p.cmd == nil || p.cmd.Process == nil { p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID) return @@ -329,7 +352,11 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) { case <-sigtermTimeout.Done(): p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID) p.cmd.Process.Kill() - case err := <-sigtermNormal: + case err := <-p.cmdWaitChan: + // Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK + // because if we make it here then the cmd has been successfully running and made it + // through the health check. There is a possibility that ithe cmd crashed after the health check + // succeeded but that's not a case llama-swap is handling for now. if err != nil { if errno, ok := err.(syscall.Errno); ok { p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno) diff --git a/proxy/process_test.go b/proxy/process_test.go index f192240..b516511 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -2,9 +2,9 @@ package proxy import ( "fmt" - "io" "net/http" "net/http/httptest" + "os" "sync" "testing" "time" @@ -13,16 +13,25 @@ import ( ) var ( - discardLogger = NewLogMonitorWriter(io.Discard) + debugLogger = NewLogMonitorWriter(os.Stdout) ) +func init() { + // flip to help with debugging tests + if false { + debugLogger.SetLogLevel(LevelDebug) + } else { + debugLogger.SetLogLevel(LevelError) + } +} + func TestProcess_AutomaticallyStartsUpstream(t *testing.T) { expectedMessage := "testing91931" config := getTestSimpleResponderConfig(expectedMessage) // Create a process - process := NewProcess("test-process", 5, config, discardLogger, discardLogger) + process := NewProcess("test-process", 5, config, debugLogger, debugLogger) defer process.Stop() req := httptest.NewRequest("GET", "/test", nil) @@ -58,7 +67,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) { expectedMessage := "testing91931" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("test-process", 5, config, discardLogger, discardLogger) + process := NewProcess("test-process", 5, config, debugLogger, debugLogger) defer process.Stop() var wg sync.WaitGroup @@ -86,7 +95,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) { CheckEndpoint: "/health", } - process := NewProcess("broken", 1, config, discardLogger, discardLogger) + process := NewProcess("broken", 1, config, debugLogger, debugLogger) req := httptest.NewRequest("GET", "/", nil) w := httptest.NewRecorder() @@ -111,7 +120,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) { config.UnloadAfter = 3 // seconds assert.Equal(t, 3, config.UnloadAfter) - process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger) + process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger) defer process.Stop() // this should take 4 seconds @@ -153,7 +162,7 @@ func TestProcess_LowTTLValue(t *testing.T) { config.UnloadAfter = 1 // second assert.Equal(t, 1, config.UnloadAfter) - process := NewProcess("ttl", 2, config, discardLogger, discardLogger) + process := NewProcess("ttl", 2, config, debugLogger, debugLogger) defer process.Stop() for i := 0; i < 100; i++ { @@ -180,7 +189,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { expectedMessage := "12345" config := getTestSimpleResponderConfig(expectedMessage) - process := NewProcess("t", 10, config, discardLogger, discardLogger) + process := NewProcess("t", 10, config, debugLogger, debugLogger) defer process.Stop() results := map[string]string{ @@ -257,7 +266,7 @@ func TestProcess_SwapState(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger) + p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger) p.state = test.currentState resultState, err := p.swapState(test.expectedState, test.newState) @@ -290,7 +299,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { config.Proxy = "http://localhost:9998/test" healthCheckTTLSeconds := 30 - process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger) + process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger) // make it a lot faster process.healthCheckLoopInterval = time.Second @@ -311,3 +320,23 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { assert.ErrorContains(t, err, "health check interrupted due to shutdown") assert.Equal(t, StateShutdown, process.CurrentState()) } + +func TestProcess_ExitInterruptsHealthCheck(t *testing.T) { + if testing.Short() { + t.Skip("skipping Exit Interrupts Health Check test") + } + + // should run and exit but interrupt the long checkHealthTimeout + checkHealthTimeout := 5 + config := ModelConfig{ + Cmd: "sleep 1", + Proxy: "http://127.0.0.1:9913", + CheckEndpoint: "/health", + } + + process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger) + process.healthCheckLoopInterval = time.Second // make it faster + err := process.start() + assert.Equal(t, "upstream command exited prematurely with no error", err.Error()) + assert.Equal(t, process.CurrentState(), StateFailed) +} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index f436d70..da7b1a9 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -49,14 +49,19 @@ func New(config *Config) *ProxyManager { switch strings.ToLower(strings.TrimSpace(config.LogLevel)) { case "debug": proxyLogger.SetLogLevel(LevelDebug) + upstreamLogger.SetLogLevel(LevelDebug) case "info": proxyLogger.SetLogLevel(LevelInfo) + upstreamLogger.SetLogLevel(LevelInfo) case "warn": proxyLogger.SetLogLevel(LevelWarn) + upstreamLogger.SetLogLevel(LevelWarn) case "error": proxyLogger.SetLogLevel(LevelError) + upstreamLogger.SetLogLevel(LevelError) default: proxyLogger.SetLogLevel(LevelInfo) + upstreamLogger.SetLogLevel(LevelInfo) } pm := &ProxyManager{ diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 15938ed..2352f48 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -22,6 +22,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { "model1": getTestSimpleResponderConfig("model1"), "model2": getTestSimpleResponderConfig("model2"), }, + LogLevel: "error", } proxy := New(config) @@ -62,6 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) { Profiles: map[string][]string{ "test": {model1, model2}, }, + LogLevel: "error", } proxy := New(config) @@ -103,6 +105,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { "model2": getTestSimpleResponderConfig("model2"), "model3": getTestSimpleResponderConfig("model3"), }, + LogLevel: "error", } proxy := New(config) @@ -153,6 +156,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { "model2": getTestSimpleResponderConfig("model2"), "model3": getTestSimpleResponderConfig("model3"), }, + LogLevel: "error", } proxy := New(config) @@ -230,6 +234,7 @@ func TestProxyManager_ProfileNonMember(t *testing.T) { Profiles: map[string][]string{ "test": {model1}, }, + LogLevel: "error", } proxy := New(config) @@ -278,6 +283,7 @@ func TestProxyManager_Shutdown(t *testing.T) { "model2": model2Config, "model3": model3Config, }, + LogLevel: "error", } proxy := New(config) @@ -313,6 +319,7 @@ func TestProxyManager_Unload(t *testing.T) { Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, + LogLevel: "error", } proxy := New(config) @@ -339,6 +346,7 @@ func TestProxyManager_StripProfileSlug(t *testing.T) { Models: map[string]ModelConfig{ "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), }, + LogLevel: "error", } proxy := New(config) @@ -365,6 +373,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { Profiles: map[string][]string{ "test": {"model1", "model2"}, }, + LogLevel: "error", } // Define a helper struct to parse the JSON response. @@ -472,6 +481,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { Models: map[string]ModelConfig{ "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), }, + LogLevel: "error", } proxy := New(config) @@ -580,6 +590,8 @@ func TestProxyManager_UseModelName(t *testing.T) { Models: map[string]ModelConfig{ "model1": modelConfig, }, + + LogLevel: "error", } proxy := New(config) @@ -647,7 +659,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { Models: map[string]ModelConfig{ "model1": getTestSimpleResponderConfig("model1"), }, - LogRequests: true, + LogLevel: "error", } tests := []struct {