Make checkHealthTimeout Interruptable during startup (#102)

interrupt and exit Process.start() early if the upstream process exits prematurely or unexpectedly.
This commit is contained in:
Benson Wong
2025-04-24 14:39:33 -07:00
committed by GitHub
parent 8404244fab
commit 5fad24c16f
4 changed files with 93 additions and 20 deletions

View File

@@ -34,6 +34,9 @@ type Process struct {
config ModelConfig
cmd *exec.Cmd
// for p.cmd.Wait() select { ... }
cmdWaitChan chan error
processLogger *LogMonitor
proxyLogger *LogMonitor
@@ -61,6 +64,7 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, processLo
ID: ID,
config: config,
cmd: nil,
cmdWaitChan: make(chan error, 1),
processLogger: processLogger,
proxyLogger: proxyLogger,
healthCheckTimeout: healthCheckTimeout,
@@ -89,16 +93,17 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
defer p.stateMutex.Unlock()
if p.state != expectedState {
p.proxyLogger.Warnf("swapState() Unexpected current state %s, expected %s", p.state, expectedState)
return p.state, ErrExpectedStateMismatch
}
if !isValidTransition(p.state, newState) {
p.proxyLogger.Warnf("Invalid state transition from %s to %s", p.state, newState)
p.proxyLogger.Warnf("swapState() Invalid state transition from %s to %s", p.state, newState)
return p.state, ErrInvalidStateTransition
}
p.proxyLogger.Debugf("State transition from %s to %s", expectedState, newState)
p.state = newState
p.proxyLogger.Debugf("swapState() State transitioned from %s to %s", expectedState, newState)
return p.state, nil
}
@@ -179,6 +184,13 @@ func (p *Process) start() error {
return fmt.Errorf("start() failed: %v", err)
}
// Capture the exit error for later signaling
go func() {
exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("cmd.Wait() returned for [%s] error: %v", p.ID, exitErr)
p.cmdWaitChan <- exitErr
}()
// One of three things can happen at this stage:
// 1. The command exits unexpectedly
// 2. The health check fails
@@ -222,6 +234,22 @@ func (p *Process) start() error {
}
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
case exitErr := <-p.cmdWaitChan:
if exitErr != nil {
p.proxyLogger.Warnf("upstream command exited prematurely with error: %v", exitErr)
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited unexpectedly: %s AND state swap failed: %v, current state: %v", exitErr.Error(), err, curState)
} else {
return fmt.Errorf("upstream command exited unexpectedly: %s", exitErr.Error())
}
} else {
p.proxyLogger.Warnf("upstream command exited prematurely with no error")
if curState, err := p.swapState(StateStarting, StateFailed); err != nil {
return fmt.Errorf("upstream command exited prematurely with no error AND state swap failed: %v, current state: %v", err, curState)
} else {
return fmt.Errorf("upstream command exited prematurely with no error")
}
}
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
p.proxyLogger.Infof("Health check passed on %s", healthURL)
@@ -257,7 +285,6 @@ func (p *Process) start() error {
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration {
p.proxyLogger.Infof("Unloading model %s, TTL of %ds reached.", p.ID, p.config.UnloadAfter)
p.Stop()
return
@@ -276,6 +303,7 @@ func (p *Process) start() error {
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.proxyLogger.Debugf("Stopping process [%s]", p.ID)
// calling Stop() when state is invalid is a no-op
if curState, err := p.swapState(StateReady, StateStopping); err != nil {
@@ -311,11 +339,6 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("Process [%s] cmd or cmd.Process is nil", p.ID)
return
@@ -329,7 +352,11 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
case <-sigtermTimeout.Done():
p.proxyLogger.Infof("Process [%s] timed out waiting to stop, sending KILL signal", p.ID)
p.cmd.Process.Kill()
case err := <-sigtermNormal:
case err := <-p.cmdWaitChan:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that ithe cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now.
if err != nil {
if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("Process [%s] errno >> %v", p.ID, errno)

View File

@@ -2,9 +2,9 @@ package proxy
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
@@ -13,16 +13,25 @@ import (
)
var (
discardLogger = NewLogMonitorWriter(io.Discard)
debugLogger = NewLogMonitorWriter(os.Stdout)
)
func init() {
// flip to help with debugging tests
if false {
debugLogger.SetLogLevel(LevelDebug)
} else {
debugLogger.SetLogLevel(LevelError)
}
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil)
@@ -58,7 +67,7 @@ func TestProcess_WaitOnMultipleStarts(t *testing.T) {
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("test-process", 5, config, discardLogger, discardLogger)
process := NewProcess("test-process", 5, config, debugLogger, debugLogger)
defer process.Stop()
var wg sync.WaitGroup
@@ -86,7 +95,7 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, discardLogger, discardLogger)
process := NewProcess("broken", 1, config, debugLogger, debugLogger)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
@@ -111,7 +120,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl_test", 2, config, discardLogger, discardLogger)
process := NewProcess("ttl_test", 2, config, debugLogger, debugLogger)
defer process.Stop()
// this should take 4 seconds
@@ -153,7 +162,7 @@ func TestProcess_LowTTLValue(t *testing.T) {
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
process := NewProcess("ttl", 2, config, discardLogger, discardLogger)
process := NewProcess("ttl", 2, config, debugLogger, debugLogger)
defer process.Stop()
for i := 0; i < 100; i++ {
@@ -180,7 +189,7 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, discardLogger, discardLogger)
process := NewProcess("t", 10, config, debugLogger, debugLogger)
defer process.Stop()
results := map[string]string{
@@ -257,7 +266,7 @@ func TestProcess_SwapState(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), discardLogger, discardLogger)
p := NewProcess("test", 10, getTestSimpleResponderConfig("test"), debugLogger, debugLogger)
p.state = test.currentState
resultState, err := p.swapState(test.expectedState, test.newState)
@@ -290,7 +299,7 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
config.Proxy = "http://localhost:9998/test"
healthCheckTTLSeconds := 30
process := NewProcess("test-process", healthCheckTTLSeconds, config, discardLogger, discardLogger)
process := NewProcess("test-process", healthCheckTTLSeconds, config, debugLogger, debugLogger)
// make it a lot faster
process.healthCheckLoopInterval = time.Second
@@ -311,3 +320,23 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) {
assert.ErrorContains(t, err, "health check interrupted due to shutdown")
assert.Equal(t, StateShutdown, process.CurrentState())
}
func TestProcess_ExitInterruptsHealthCheck(t *testing.T) {
if testing.Short() {
t.Skip("skipping Exit Interrupts Health Check test")
}
// should run and exit but interrupt the long checkHealthTimeout
checkHealthTimeout := 5
config := ModelConfig{
Cmd: "sleep 1",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("sleepy", checkHealthTimeout, config, debugLogger, debugLogger)
process.healthCheckLoopInterval = time.Second // make it faster
err := process.start()
assert.Equal(t, "upstream command exited prematurely with no error", err.Error())
assert.Equal(t, process.CurrentState(), StateFailed)
}

View File

@@ -49,14 +49,19 @@ func New(config *Config) *ProxyManager {
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
case "debug":
proxyLogger.SetLogLevel(LevelDebug)
upstreamLogger.SetLogLevel(LevelDebug)
case "info":
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
case "warn":
proxyLogger.SetLogLevel(LevelWarn)
upstreamLogger.SetLogLevel(LevelWarn)
case "error":
proxyLogger.SetLogLevel(LevelError)
upstreamLogger.SetLogLevel(LevelError)
default:
proxyLogger.SetLogLevel(LevelInfo)
upstreamLogger.SetLogLevel(LevelInfo)
}
pm := &ProxyManager{

View File

@@ -22,6 +22,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -62,6 +63,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
Profiles: map[string][]string{
"test": {model1, model2},
},
LogLevel: "error",
}
proxy := New(config)
@@ -103,6 +105,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -153,6 +156,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -230,6 +234,7 @@ func TestProxyManager_ProfileNonMember(t *testing.T) {
Profiles: map[string][]string{
"test": {model1},
},
LogLevel: "error",
}
proxy := New(config)
@@ -278,6 +283,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
"model2": model2Config,
"model3": model3Config,
},
LogLevel: "error",
}
proxy := New(config)
@@ -313,6 +319,7 @@ func TestProxyManager_Unload(t *testing.T) {
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -339,6 +346,7 @@ func TestProxyManager_StripProfileSlug(t *testing.T) {
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -365,6 +373,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
Profiles: map[string][]string{
"test": {"model1", "model2"},
},
LogLevel: "error",
}
// Define a helper struct to parse the JSON response.
@@ -472,6 +481,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
LogLevel: "error",
}
proxy := New(config)
@@ -580,6 +590,8 @@ func TestProxyManager_UseModelName(t *testing.T) {
Models: map[string]ModelConfig{
"model1": modelConfig,
},
LogLevel: "error",
}
proxy := New(config)
@@ -647,7 +659,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogRequests: true,
LogLevel: "error",
}
tests := []struct {