From d625ab8d92ba45096ae8164136a08a9732a91710 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Sat, 15 Mar 2025 17:14:03 -0700 Subject: [PATCH] Refactor process state management (#70) (#73) * add isValidStateTransition helper function * Replace Process.setState() with Process.swapState() * Refactor locking logic in Process --- proxy/process.go | 208 ++++++++++++++++++++---------------------- proxy/process_test.go | 47 +++++----- 2 files changed, 123 insertions(+), 132 deletions(-) diff --git a/proxy/process.go b/proxy/process.go index db3bb5d..73e56d0 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -30,11 +30,13 @@ const ( ) type Process struct { - ID string - config ModelConfig - cmd *exec.Cmd - logMonitor *LogMonitor - healthCheckTimeout int + ID string + config ModelConfig + cmd *exec.Cmd + logMonitor *LogMonitor + + healthCheckTimeout int + healthCheckLoopInterval time.Duration lastRequestHandled time.Time @@ -54,51 +56,57 @@ type Process struct { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process { ctx, cancel := context.WithCancel(context.Background()) return &Process{ - ID: ID, - config: config, - cmd: nil, - logMonitor: logMonitor, - healthCheckTimeout: healthCheckTimeout, - state: StateStopped, - shutdownCtx: ctx, - shutdownCancel: cancel, + ID: ID, + config: config, + cmd: nil, + logMonitor: logMonitor, + healthCheckTimeout: healthCheckTimeout, + healthCheckLoopInterval: 5 * time.Second, /* default, can not be set by user - used for testing */ + state: StateStopped, + shutdownCtx: ctx, + shutdownCancel: cancel, } } -func (p *Process) setState(newState ProcessState) error { - // enforce valid state transitions - invalidTransition := false - if p.state == StateStopped { - // stopped -> starting - if newState != StateStarting { - invalidTransition = true - } - } else if p.state == StateStarting { - // starting -> ready | failed | stopping - if newState != StateReady && newState != StateFailed && newState != StateStopping { - invalidTransition = true - } - } else if p.state == StateReady { - // ready -> stopping - if newState != StateStopping { - invalidTransition = true - } - } else if p.state == StateStopping { - // stopping -> stopped | shutdown - if newState != StateStopped && newState != StateShutdown { - invalidTransition = true - } - } else if p.state == StateFailed || p.state == StateShutdown { - invalidTransition = true +// custom error types for swapping state +var ( + ErrExpectedStateMismatch = errors.New("expected state mismatch") + ErrInvalidStateTransition = errors.New("invalid state transition") +) + +// swapState performs a compare and swap of the state atomically. It returns the current state +// and an error if the swap failed. +func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, error) { + p.stateMutex.Lock() + defer p.stateMutex.Unlock() + + if p.state != expectedState { + return p.state, ErrExpectedStateMismatch } - if invalidTransition { - //panic(fmt.Sprintf("Invalid state transition from %s to %s", p.state, newState)) - return fmt.Errorf("invalid state transition from %s to %s", p.state, newState) + if !isValidTransition(p.state, newState) { + return p.state, ErrInvalidStateTransition } p.state = newState - return nil + return p.state, nil +} + +// Helper function to encapsulate transition rules +func isValidTransition(from, to ProcessState) bool { + switch from { + case StateStopped: + return to == StateStarting + case StateStarting: + return to == StateReady || to == StateFailed || to == StateStopping + case StateReady: + return to == StateStopping + case StateStopping: + return to == StateStopped || to == StateShutdown + case StateFailed, StateShutdown: + return false // No transitions allowed from these states + } + return false } func (p *Process) CurrentState() ProcessState { @@ -116,56 +124,33 @@ func (p *Process) start() error { return fmt.Errorf("can not start(), upstream proxy missing") } - // multiple start() calls will wait for the one that is actually starting to - // complete before proceeding. - // =========== - curState := p.CurrentState() - - if curState == StateReady { - return nil - } - - if curState == StateStarting { - p.waitStarting.Wait() - - if state := p.CurrentState(); state != StateReady { - return fmt.Errorf("start() failed current state: %v", state) - } - - return nil - } - // =========== - - // There is the possibility of a hard to replicate race condition where - // curState *WAS* StateStopped but by the time we get to the p.stateMutex.Lock() - // below, it's value has changed! - - p.stateMutex.Lock() - defer p.stateMutex.Unlock() - - // with the exclusive lock, check if p.state is StateStopped, which is the only valid state - // to transition from to StateReady - - if p.state != StateStopped { - if p.state == StateReady { - return nil - } else { - return fmt.Errorf("start() can not proceed expected StateReady but process is in %v", p.state) - } - } - - if err := p.setState(StateStarting); err != nil { - return err - } - - p.waitStarting.Add(1) - defer p.waitStarting.Done() - args, err := p.config.SanitizedCommand() if err != nil { return fmt.Errorf("unable to get sanitized command: %v", err) } + if curState, err := p.swapState(StateStopped, StateStarting); err != nil { + if err == ErrExpectedStateMismatch { + // already starting, just wait for it to complete and expect + // it to be be in the Ready start after. If not, return an error + if curState == StateStarting { + p.waitStarting.Wait() + if state := p.CurrentState(); state == StateReady { + return nil + } else { + return fmt.Errorf("process was already starting but wound up in state %v", state) + } + } else { + return fmt.Errorf("processes was in state %v when start() was called", curState) + } + } else { + return fmt.Errorf("failed to set Process state to starting: current state: %v, error: %v", curState, err) + } + } + + p.waitStarting.Add(1) + defer p.waitStarting.Done() + p.cmd = exec.Command(args[0], args[1:]...) p.cmd.Stdout = p.logMonitor p.cmd.Stderr = p.logMonitor @@ -173,8 +158,14 @@ func (p *Process) start() error { err = p.cmd.Start() + // Set process state to failed if err != nil { - p.setState(StateFailed) + if curState, swapErr := p.swapState(StateStarting, StateFailed); err != nil { + return fmt.Errorf( + "failed to start command and state swap failed. command error: %v, current state: %v, state swap error: %v", + err, curState, swapErr, + ) + } return fmt.Errorf("start() failed: %v", err) } @@ -209,13 +200,16 @@ func (p *Process) start() error { ) defer cancelHealthCheck() - // Health check loop loop: + // Ready Check loop for { select { case <-checkDeadline.Done(): - p.setState(StateFailed) - return fmt.Errorf("health check failed after %vs", maxDuration.Seconds()) + if curState, err := p.swapState(StateStarting, StateFailed); err != nil { + return fmt.Errorf("health check timed out after %vs AND state swap failed: %v, current state: %v", maxDuration.Seconds(), err, curState) + } else { + return fmt.Errorf("health check timed out after %vs", maxDuration.Seconds()) + } case <-p.shutdownCtx.Done(): return errors.New("health check interrupted due to shutdown") default: @@ -233,7 +227,7 @@ func (p *Process) start() error { } } - <-time.After(5 * time.Second) + <-time.After(p.healthCheckLoopInterval) } } @@ -244,7 +238,7 @@ func (p *Process) start() error { maxDuration := time.Duration(p.config.UnloadAfter) * time.Second for range time.Tick(time.Second) { - if p.state != StateReady { + if p.CurrentState() != StateReady { return } @@ -260,26 +254,28 @@ func (p *Process) start() error { }() } - return p.setState(StateReady) + if curState, err := p.swapState(StateStarting, StateReady); err != nil { + return fmt.Errorf("failed to set Process state to ready: current state: %v, error: %v", curState, err) + } else { + return nil + } } func (p *Process) Stop() { // wait for any inflight requests before proceeding p.inFlightRequests.Wait() - p.stateMutex.Lock() - defer p.stateMutex.Unlock() // calling Stop() when state is invalid is a no-op - if err := p.setState(StateStopping); err != nil { - fmt.Fprintf(p.logMonitor, "!!! Info - Stop() err: %v\n", err) + if curState, err := p.swapState(StateReady, StateStopping); err != nil { + fmt.Fprintf(p.logMonitor, "!!! Info - Stop() Ready -> StateStopping err: %v, current state: %v\n", err, curState) return } // stop the process with a graceful exit timeout p.stopCommand(5 * time.Second) - if err := p.setState(StateStopped); err != nil { - panic(fmt.Sprintf("Stop() failed to set state to stopped: %v", err)) + if curState, err := p.swapState(StateStopping, StateStopped); err != nil { + fmt.Fprintf(p.logMonitor, "!!! Info - Stop() StateStopping -> StateStopped err: %v, current state: %v\n", err, curState) } } @@ -287,19 +283,9 @@ func (p *Process) Stop() { // of time for any inflight requests to complete before shutting down. If the Process // is in the state of starting, it will cancel it and shut it down func (p *Process) Shutdown() { - // cancel anything that can be interrupted by a shutdown (ie: healthcheck) p.shutdownCancel() - - p.stateMutex.Lock() - defer p.stateMutex.Unlock() - p.setState(StateStopping) - - // 5 seconds to stop the process p.stopCommand(5 * time.Second) - if err := p.setState(StateShutdown); err != nil { - fmt.Printf("!!! Shutdown() failed to set state to shutdown: %v", err) - } - p.setState(StateShutdown) + p.state = StateShutdown } // stopCommand will send a SIGTERM to the process and wait for it to exit. diff --git a/proxy/process_test.go b/proxy/process_test.go index 02ba913..a42a2aa 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -225,30 +225,32 @@ func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { } } -func TestSetState(t *testing.T) { +func TestProcess_SwapState(t *testing.T) { tests := []struct { name string currentState ProcessState + expectedState ProcessState newState ProcessState expectedError error expectedResult ProcessState }{ - {"Stopped to Starting", StateStopped, StateStarting, nil, StateStarting}, - {"Starting to Ready", StateStarting, StateReady, nil, StateReady}, - {"Starting to Failed", StateStarting, StateFailed, nil, StateFailed}, - {"Starting to Stopping", StateStarting, StateStopping, nil, StateStopping}, - {"Ready to Stopping", StateReady, StateStopping, nil, StateStopping}, - {"Stopping to Stopped", StateStopping, StateStopped, nil, StateStopped}, - {"Stopping to Shutdown", StateStopping, StateShutdown, nil, StateShutdown}, - {"Stopped to Ready", StateStopped, StateReady, fmt.Errorf("invalid state transition from stopped to ready"), StateStopped}, - {"Starting to Stopped", StateStarting, StateStopped, fmt.Errorf("invalid state transition from starting to stopped"), StateStarting}, - {"Ready to Starting", StateReady, StateStarting, fmt.Errorf("invalid state transition from ready to starting"), StateReady}, - {"Ready to Failed", StateReady, StateFailed, fmt.Errorf("invalid state transition from ready to failed"), StateReady}, - {"Stopping to Ready", StateStopping, StateReady, fmt.Errorf("invalid state transition from stopping to ready"), StateStopping}, - {"Failed to Stopped", StateFailed, StateStopped, fmt.Errorf("invalid state transition from failed to stopped"), StateFailed}, - {"Failed to Starting", StateFailed, StateStarting, fmt.Errorf("invalid state transition from failed to starting"), StateFailed}, - {"Shutdown to Stopped", StateShutdown, StateStopped, fmt.Errorf("invalid state transition from shutdown to stopped"), StateShutdown}, - {"Shutdown to Starting", StateShutdown, StateStarting, fmt.Errorf("invalid state transition from shutdown to starting"), StateShutdown}, + {"Stopped to Starting", StateStopped, StateStopped, StateStarting, nil, StateStarting}, + {"Starting to Ready", StateStarting, StateStarting, StateReady, nil, StateReady}, + {"Starting to Failed", StateStarting, StateStarting, StateFailed, nil, StateFailed}, + {"Starting to Stopping", StateStarting, StateStarting, StateStopping, nil, StateStopping}, + {"Ready to Stopping", StateReady, StateReady, StateStopping, nil, StateStopping}, + {"Stopping to Stopped", StateStopping, StateStopping, StateStopped, nil, StateStopped}, + {"Stopping to Shutdown", StateStopping, StateStopping, StateShutdown, nil, StateShutdown}, + {"Stopped to Ready", StateStopped, StateStopped, StateReady, ErrInvalidStateTransition, StateStopped}, + {"Starting to Stopped", StateStarting, StateStarting, StateStopped, ErrInvalidStateTransition, StateStarting}, + {"Ready to Starting", StateReady, StateReady, StateStarting, ErrInvalidStateTransition, StateReady}, + {"Ready to Failed", StateReady, StateReady, StateFailed, ErrInvalidStateTransition, StateReady}, + {"Stopping to Ready", StateStopping, StateStopping, StateReady, ErrInvalidStateTransition, StateStopping}, + {"Failed to Stopped", StateFailed, StateFailed, StateStopped, ErrInvalidStateTransition, StateFailed}, + {"Failed to Starting", StateFailed, StateFailed, StateStarting, ErrInvalidStateTransition, StateFailed}, + {"Shutdown to Stopped", StateShutdown, StateShutdown, StateStopped, ErrInvalidStateTransition, StateShutdown}, + {"Shutdown to Starting", StateShutdown, StateShutdown, StateStarting, ErrInvalidStateTransition, StateShutdown}, + {"Expected state mismatch", StateStopped, StateStarting, StateStarting, ErrExpectedStateMismatch, StateStopped}, } for _, test := range tests { @@ -257,7 +259,7 @@ func TestSetState(t *testing.T) { state: test.currentState, } - err := p.setState(test.newState) + resultState, err := p.swapState(test.expectedState, test.newState) if err != nil && test.expectedError == nil { t.Errorf("Unexpected error: %v", err) } else if err == nil && test.expectedError != nil { @@ -268,8 +270,8 @@ func TestSetState(t *testing.T) { } } - if p.state != test.expectedResult { - t.Errorf("Expected state: %v, got: %v", test.expectedResult, p.state) + if resultState != test.expectedResult { + t.Errorf("Expected state: %v, got: %v", test.expectedResult, resultState) } }) } @@ -290,11 +292,14 @@ func TestProcess_ShutdownInterruptsHealthCheck(t *testing.T) { healthCheckTTLSeconds := 30 process := NewProcess("test-process", healthCheckTTLSeconds, config, logMonitor) + // make it a lot faster + process.healthCheckLoopInterval = time.Second + // start a goroutine to simulate a shutdown var wg sync.WaitGroup go func() { defer wg.Done() - <-time.After(time.Second * 2) + <-time.After(time.Millisecond * 500) process.Shutdown() }() wg.Add(1)