Improve shutdown behaviour (#47) (#49)

Introduce `Process.Shutdown()` and `ProxyManager.Shutdown()`. These two function required a lot of internal process state management refactoring. A key benefit is that `Process.start()` is now interruptable. When `Shutdown()` is called it will break the long health check loop. 

State management within Process is also improved. Added `starting`, `stopping` and `shutdown` states. Additionally, introduced a simple finite state machine to manage transitions.
This commit is contained in:
Benson Wong
2025-02-05 17:19:59 -08:00
committed by GitHub
parent 85cd74a51c
commit 09bdd86b54
5 changed files with 398 additions and 98 deletions

View File

@@ -2,6 +2,7 @@ package proxy
import (
"context"
"errors"
"fmt"
"io"
"net/http"
@@ -16,14 +17,19 @@ import (
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateReady ProcessState = ProcessState("ready")
StateFailed ProcessState = ProcessState("failed")
StateStopped ProcessState = ProcessState("stopped")
StateStarting ProcessState = ProcessState("starting")
StateReady ProcessState = ProcessState("ready")
StateStopping ProcessState = ProcessState("stopping")
// failed a health check on start and will not be recovered
StateFailed ProcessState = ProcessState("failed")
// process is shutdown and will not be restarted
StateShutdown ProcessState = ProcessState("shutdown")
)
type Process struct {
sync.Mutex
ID string
config ModelConfig
cmd *exec.Cmd
@@ -36,9 +42,17 @@ type Process struct {
state ProcessState
inFlightRequests sync.WaitGroup
// used to block on multiple start() calls
waitStarting sync.WaitGroup
// for managing shutdown state
shutdownCtx context.Context
shutdownCancel context.CancelFunc
}
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
ctx, cancel := context.WithCancel(context.Background())
return &Process{
ID: ID,
config: config,
@@ -46,22 +60,88 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
shutdownCtx: ctx,
shutdownCancel: cancel,
}
}
// start the process and returns when it is ready
func (p *Process) setState(newState ProcessState) error {
// enforce valid state transitions
invalidTransition := false
if p.state == StateStopped {
// stopped -> starting
if newState != StateStarting {
invalidTransition = true
}
} else if p.state == StateStarting {
// starting -> ready | failed | stopping
if newState != StateReady && newState != StateFailed && newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateReady {
// ready -> stopping
if newState != StateStopping {
invalidTransition = true
}
} else if p.state == StateStopping {
// stopping -> stopped | shutdown
if newState != StateStopped && newState != StateShutdown {
invalidTransition = true
}
} else if p.state == StateFailed || p.state == StateShutdown {
invalidTransition = true
}
if invalidTransition {
//panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState))
return fmt.Errorf("invalid state transition from %s to %s", p.state, newState)
}
p.state = newState
return nil
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
// start starts the upstream command, checks the health endpoint, and sets the state to Ready
// it is a private method because starting is automatic but stopping can be called
// at any time.
func (p *Process) start() error {
if p.config.Proxy == "" {
return fmt.Errorf("can not start(), upstream proxy missing")
}
// wait for the other start() to complete
curState := p.CurrentState()
if curState == StateReady {
return nil
}
if curState == StateStarting {
p.waitStarting.Wait()
if state := p.CurrentState(); state != StateReady {
return fmt.Errorf("start() failed current state: %v", state)
}
return nil
}
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state == StateReady {
return nil
if err := p.setState(StateStarting); err != nil {
return err
}
if p.state == StateFailed {
return fmt.Errorf("process is in a failed state and can not be restarted")
}
p.waitStarting.Add(1)
defer p.waitStarting.Done()
args, err := p.config.SanitizedCommand()
if err != nil {
@@ -76,7 +156,8 @@ func (p *Process) start() error {
err = p.cmd.Start()
if err != nil {
return err
p.setState(StateFailed)
return fmt.Errorf("start() failed: %v", err)
}
// One of three things can happen at this stage:
@@ -87,9 +168,55 @@ func (p *Process) start() error {
// only in the third case will the process be considered Ready to accept
<-time.After(250 * time.Millisecond) // give process a bit of time to start
if err := p.checkHealthEndpoint(); err != nil {
p.state = StateFailed
return err
checkStartTime := time.Now()
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
// a "none" means don't check for health ... I could have picked a better word :facepalm:
if checkEndpoint != "none" {
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := p.config.Proxy
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil {
return fmt.Errorf("failed to create health check URL proxy=%s and checkEndpoint=%s", proxyTo, checkEndpoint)
}
checkDeadline, cancelHealthCheck := context.WithDeadline(
context.Background(),
checkStartTime.Add(maxDuration),
)
defer cancelHealthCheck()
// Health check loop
loop:
for {
select {
case <-checkDeadline.Done():
p.setState(StateFailed)
return fmt.Errorf("health check failed after %vs", maxDuration.Seconds())
case <-p.shutdownCtx.Done():
return errors.New("health check interrupted due to shutdown")
default:
if err := p.checkHealthEndpoint(healthURL); err == nil {
cancelHealthCheck()
break loop
} else {
if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime)
fmt.Fprintf(p.logMonitor, "!!! Connection refused on %s, ttl %.0fs\n", healthURL, ttl.Seconds())
} else {
fmt.Fprintf(p.logMonitor, "!!! Health check error: %v\n", err)
}
}
}
<-time.After(time.Second)
}
}
if p.config.UnloadAfter > 0 {
@@ -115,37 +242,63 @@ func (p *Process) start() error {
}()
}
p.state = StateReady
return nil
return p.setState(StateReady)
}
func (p *Process) Stop() {
// wait for any inflight requests before proceeding
p.inFlightRequests.Wait()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != StateReady {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() called but Process State is not READY\n")
// calling Stop() when state is invalid is a no-op
if err := p.setState(StateStopping); err != nil {
fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err)
return
}
if p.cmd == nil || p.cmd.Process == nil {
// this situation should never happen... but if it does just update the state
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.\n")
p.state = StateStopped
return
}
// stop the process with a graceful exit timeout
p.stopCommand(5 * time.Second)
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := p.setState(StateStopped); err != nil {
panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err))
}
}
// Shutdown is called when llama-swap is shutting down. It will give a little bit
// of time for any inflight requests to complete before shutting down. If the Process
// is in the state of starting, it will cancel it and shut it down
func (p *Process) Shutdown() {
// cancel anything that can be interrupted by a shutdown (ie: healthcheck)
p.shutdownCancel()
p.stateMutex.Lock()
defer p.stateMutex.Unlock()
p.setState(StateStopping)
// 5 seconds to stop the process
p.stopCommand(5 * time.Second)
if err := p.setState(StateShutdown); err != nil {
fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err)
}
p.setState(StateShutdown)
}
// stopCommand will send a SIGTERM to the process and wait for it to exit.
// If it does not exit within 5 seconds, it will send a SIGKILL.
func (p *Process) stopCommand(sigtermTTL time.Duration) {
sigtermTimeout, cancelTimeout := context.WithTimeout(context.Background(), sigtermTTL)
defer cancelTimeout()
sigtermNormal := make(chan error, 1)
go func() {
sigtermNormal <- p.cmd.Wait()
}()
if p.cmd == nil || p.cmd.Process == nil {
panic("this should not happen, cmd or cmd.Process is nil")
}
p.cmd.Process.Signal(syscall.SIGTERM)
select {
@@ -170,94 +323,53 @@ func (p *Process) Stop() {
}
}
}
p.state = StateStopped
}
func (p *Process) CurrentState() ProcessState {
p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
}
func (p *Process) checkHealthEndpoint(healthURL string) error {
func (p *Process) checkHealthEndpoint() error {
if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health")
client := &http.Client{
Timeout: 500 * time.Millisecond,
}
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
if checkEndpoint == "none" {
return nil
}
// keep default behaviour
if checkEndpoint == "" {
checkEndpoint = "/health"
}
proxyTo := p.config.Proxy
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
return err
}
client := &http.Client{}
startTime := time.Now()
for {
req, err := http.NewRequest("GET", healthURL, nil)
if err != nil {
return err
}
resp, err := client.Do(req)
ttl := (maxDuration - time.Since(startTime)).Seconds()
if err != nil {
// wait a bit longer for TCP connection issues
if strings.Contains(err.Error(), "connection refused") {
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
time.Sleep(5 * time.Second)
} else {
time.Sleep(time.Second)
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
continue
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusOK {
return nil
}
if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL)
}
time.Sleep(time.Second)
resp, err := client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// got a response but it was not an OK
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("status code: %d", resp.StatusCode)
}
return nil
}
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Add(1)
// prevent new requests from being made while stopping or irrecoverable
currentState := p.CurrentState()
if currentState == StateFailed || currentState == StateShutdown || currentState == StateStopping {
http.Error(w, fmt.Sprintf("Process can not ProxyRequest, state is %s", currentState), http.StatusServiceUnavailable)
return
}
p.inFlightRequests.Add(1)
defer func() {
p.lastRequestHandled = time.Now()
p.inFlightRequests.Done()
}()
// start the process on demand
if p.CurrentState() != StateReady {
if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusInternalServerError)
http.Error(w, errstr, http.StatusBadGateway)
return
}
}