Fix race conditions in proxy.Process (#349)

- Fix data races in proxy.Process found by Go's race detector.
- Add data race detection (go test -race) to the CI tests.

Fixes #348
Author: Benson Wong (committed by GitHub)
Date:   2025-10-13 16:42:49 -07:00
Parent: 539278343b
Commit: caf9e98b1e

3 changed files with 79 additions and 20 deletions

Makefile

@@ -33,7 +33,7 @@ test: proxy/ui_dist/placeholder.txt
 # for CI - full test (takes longer)
 test-all: proxy/ui_dist/placeholder.txt
-	go test -count=1 ./proxy/...
+	go test -race -count=1 ./proxy/...

 ui/node_modules:
 	cd ui && npm install
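For context: -race builds the test binary with Go's race detector, which reports unsynchronized concurrent memory accesses at runtime rather than at compile time. A minimal sketch of the kind of bug it catches (illustrative only, not code from this repo):

    package main

    import "fmt"

    func main() {
        counter := 0
        done := make(chan struct{})
        go func() { // unsynchronized writer
            counter++
            close(done)
        }()
        counter++ // races with the goroutine's write
        <-done
        fmt.Println(counter) // `go run -race .` reports a DATA RACE here
    }

The detector only flags races that actually occur during a run and roughly doubles test time, which is why it is enabled on the slower CI target (test-all) rather than the quick local test target.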

proxy/process.go

@@ -12,6 +12,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@@ -44,6 +45,7 @@ type Process struct {
 	cmd *exec.Cmd
 	// PR #155 called to cancel the upstream process
+	cmdMutex       sync.RWMutex
 	cancelUpstream context.CancelFunc

 	// closed when command exits
@@ -55,12 +57,14 @@ type Process struct {
 	healthCheckTimeout      int
 	healthCheckLoopInterval time.Duration

-	lastRequestHandled time.Time
+	lastRequestHandledMutex sync.RWMutex
+	lastRequestHandled      time.Time

 	stateMutex sync.RWMutex
 	state      ProcessState

 	inFlightRequests      sync.WaitGroup
+	inFlightRequestsCount atomic.Int32

 	// used to block on multiple start() calls
 	waitStarting sync.WaitGroup
@@ -107,6 +111,20 @@ func (p *Process) LogMonitor() *LogMonitor {
 	return p.processLogger
 }

+// setLastRequestHandled sets the last request handled time in a thread-safe manner.
+func (p *Process) setLastRequestHandled(t time.Time) {
+	p.lastRequestHandledMutex.Lock()
+	defer p.lastRequestHandledMutex.Unlock()
+	p.lastRequestHandled = t
+}
+
+// getLastRequestHandled gets the last request handled time in a thread-safe manner.
+func (p *Process) getLastRequestHandled() time.Time {
+	p.lastRequestHandledMutex.RLock()
+	defer p.lastRequestHandledMutex.RUnlock()
+	return p.lastRequestHandled
+}
+
 // custom error types for swapping state
 var (
 	ErrExpectedStateMismatch = errors.New("expected state mismatch")
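The accessor pair above exists because time.Time is a multi-word struct: a bare concurrent read and write can observe a torn value, which is exactly what the race detector flags. A minimal sketch of the same guarded-field pattern in isolation (illustrative names, not repo code):

    package main

    import (
        "fmt"
        "sync"
        "time"
    )

    // clock guards a time.Time the way Process guards lastRequestHandled:
    // every read and write goes through the RWMutex.
    type clock struct {
        mu   sync.RWMutex
        last time.Time
    }

    func (c *clock) set(t time.Time) {
        c.mu.Lock()
        defer c.mu.Unlock()
        c.last = t
    }

    func (c *clock) get() time.Time {
        c.mu.RLock()
        defer c.mu.RUnlock()
        return c.last
    }

    func main() {
        c := &clock{}
        done := make(chan struct{})
        go func() { // concurrent writer: safe under -race
            defer close(done)
            c.set(time.Now())
        }()
        _ = c.get() // concurrent reader: also safe
        <-done
        fmt.Println("idle for:", time.Since(c.get()))
    }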
@@ -130,6 +148,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
 	}

 	p.state = newState
+
+	// Atomically increment waitStarting when entering StateStarting
+	// This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
+	if newState == StateStarting {
+		p.waitStarting.Add(1)
+	}
+
 	p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
 	event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
 	return p.state, nil
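The point of moving waitStarting.Add(1) under stateMutex is ordering: previously start() incremented the counter after the state had already been swapped to StateStarting, so another goroutine could observe StateStarting, call waitStarting.Wait() before the Add(1) landed, and return immediately against a half-started process. Doing the Add inside the same critical section as the state write closes that window. A sketch of the invariant (illustrative names, not repo code):

    package main

    import (
        "fmt"
        "sync"
    )

    type starter struct {
        mu       sync.Mutex
        starting bool
        wg       sync.WaitGroup
    }

    func (s *starter) begin() {
        s.mu.Lock()
        s.starting = true
        s.wg.Add(1) // same critical section as the state write
        s.mu.Unlock()
    }

    func (s *starter) finish() {
        s.mu.Lock()
        s.starting = false
        s.mu.Unlock()
        s.wg.Done()
    }

    // waitIfStarting can never call Wait() before the matching Add(1):
    // observing starting==true implies the increment already happened.
    func (s *starter) waitIfStarting() {
        s.mu.Lock()
        starting := s.starting
        s.mu.Unlock()
        if starting {
            s.wg.Wait()
        }
    }

    func main() {
        s := &starter{}
        s.begin()
        go s.finish()
        s.waitIfStarting()
        fmt.Println("start completed")
    }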
@@ -158,6 +183,15 @@ func (p *Process) CurrentState() ProcessState {
 	return p.state
 }

+// forceState forces the process state to the new state with mutex protection.
+// This should only be used in exceptional cases where the normal state transition
+// validation via swapState() cannot be used.
+func (p *Process) forceState(newState ProcessState) {
+	p.stateMutex.Lock()
+	defer p.stateMutex.Unlock()
+	p.state = newState
+}
+
 // start starts the upstream command, checks the health endpoint, and sets the state to Ready
 // it is a private method because starting is automatic but stopping can be called
 // at any time.
@@ -191,7 +225,7 @@ func (p *Process) start() error {
 		}
 	}

-	p.waitStarting.Add(1)
+	// waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
 	defer p.waitStarting.Done()

 	cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
@@ -201,8 +235,11 @@ func (p *Process) start() error {
 	p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
 	p.cmd.Cancel = p.cmdStopUpstreamProcess
 	p.cmd.WaitDelay = p.gracefulStopTimeout

+	p.cmdMutex.Lock()
 	p.cancelUpstream = ctxCancelUpstream
 	p.cmdWaitChan = make(chan struct{})
+	p.cmdMutex.Unlock()

 	p.failedStartCount++ // this will be reset to zero when the process has successfully started
@@ -212,7 +249,7 @@ func (p *Process) start() error {
 	// Set process state to failed
 	if err != nil {
 		if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
-			p.state = StateStopped // force it into a stopped state
+			p.forceState(StateStopped) // force it into a stopped state
 			return fmt.Errorf(
 				"failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
 				strings.Join(args, " "), err, curState, swapErr,
@@ -285,10 +322,12 @@ func (p *Process) start() error {
 				return
 			}

-			// wait for all inflight requests to complete and ticker
-			p.inFlightRequests.Wait()
+			// skip the TTL check if there are inflight requests
+			if p.inFlightRequestsCount.Load() != 0 {
+				continue
+			}

-			if time.Since(p.lastRequestHandled) > maxDuration {
+			if time.Since(p.getLastRequestHandled()) > maxDuration {
 				p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
 				p.Stop()
 				return
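The old code parked the TTL ticker goroutine on inFlightRequests.Wait(); besides blocking the loop, Add calls that race a Wait once the counter reaches zero are documented as invalid WaitGroup usage, so a plain counter is the safer fit here. The atomic counter turns the idle check into a non-blocking poll: busy ticks are skipped and the check simply runs again on the next tick. A sketch of the loop shape (illustrative, not repo code):

    package main

    import (
        "fmt"
        "sync/atomic"
        "time"
    )

    func main() {
        var inFlight atomic.Int32

        inFlight.Add(1) // pretend one request is active
        go func() {
            time.Sleep(30 * time.Millisecond)
            inFlight.Add(-1) // request completes
        }()

        ticker := time.NewTicker(10 * time.Millisecond)
        defer ticker.Stop()
        for range ticker.C {
            if inFlight.Load() != 0 {
                continue // busy: skip this tick instead of blocking
            }
            fmt.Println("idle: safe to unload the model")
            return
        }
    }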
@@ -344,7 +383,7 @@ func (p *Process) Shutdown() {
 	p.stopCommand()

 	// just force it to this state since there is no recovery from shutdown
-	p.state = StateShutdown
+	p.forceState(StateShutdown)
 }

 // stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -355,13 +394,18 @@ func (p *Process) stopCommand() {
 		p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
 	}()

-	if p.cancelUpstream == nil {
+	p.cmdMutex.RLock()
+	cancelUpstream := p.cancelUpstream
+	cmdWaitChan := p.cmdWaitChan
+	p.cmdMutex.RUnlock()
+
+	if cancelUpstream == nil {
 		p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
 		return
 	}

-	p.cancelUpstream()
-	<-p.cmdWaitChan
+	cancelUpstream()
+	<-cmdWaitChan
 }

 func (p *Process) checkHealthEndpoint(healthURL string) error {
@@ -418,8 +462,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
 	}

 	p.inFlightRequests.Add(1)
+	p.inFlightRequestsCount.Add(1)
 	defer func() {
-		p.lastRequestHandled = time.Now()
+		p.setLastRequestHandled(time.Now())
+		p.inFlightRequestsCount.Add(-1)
 		p.inFlightRequests.Done()
 	}()
@@ -519,13 +565,16 @@ func (p *Process) waitForCmd() {
 	case StateStopping:
 		if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
 			p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
-			p.state = StateStopped
+			p.forceState(StateStopped)
 		}
 	default:
 		p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
-		p.state = StateStopped // force it to be in this state
+		p.forceState(StateStopped) // force it to be in this state
 	}

+	p.cmdMutex.Lock()
 	close(p.cmdWaitChan)
+	p.cmdMutex.Unlock()
 }

 // cmdStopUpstreamProcess attempts to stop the upstream process gracefully
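Taken together, the cmdMutex changes in start(), stopCommand(), and waitForCmd() form one pattern: writers (start() reassigning cancelUpstream and cmdWaitChan, waitForCmd() closing the channel) take the write lock, while stopCommand() snapshots both fields under the read lock and then acts on the copies, so the lock is never held across the slow cancel-and-wait. A condensed sketch of that pattern (illustrative, not repo code):

    package main

    import (
        "context"
        "sync"
        "time"
    )

    type runner struct {
        mu       sync.RWMutex
        cancel   context.CancelFunc
        waitChan chan struct{}
    }

    func (r *runner) start() {
        ctx, cancel := context.WithCancel(context.Background())
        r.mu.Lock() // writer side: reassign under the write lock
        r.cancel = cancel
        r.waitChan = make(chan struct{})
        r.mu.Unlock()

        go func(ch chan struct{}) {
            <-ctx.Done() // stand-in for the upstream process exiting
            r.mu.Lock()  // writer side: close under the write lock too
            close(ch)
            r.mu.Unlock()
        }(r.waitChan)
    }

    func (r *runner) stop() {
        r.mu.RLock() // reader side: snapshot, then release before blocking
        cancel := r.cancel
        waitChan := r.waitChan
        r.mu.RUnlock()

        if cancel == nil {
            return
        }
        cancel()
        <-waitChan
    }

    func main() {
        r := &runner{}
        r.start()
        time.Sleep(10 * time.Millisecond)
        r.stop()
    }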

proxy/proxymanager_test.go

@@ -1075,18 +1075,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
 	for _, endpoint := range endpoints {
 		t.Run(endpoint, func(t *testing.T) {
-			ctx, cancel := context.WithCancel(context.Background())
+			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
 			defer cancel()

 			req := httptest.NewRequest("GET", endpoint, nil)
 			req = req.WithContext(ctx)
 			rec := httptest.NewRecorder()

-			// We don't need the handler to fully complete, just to set the headers
-			// so run it in a goroutine and check the headers after a short delay
-			go proxy.ServeHTTP(rec, req)
-			time.Sleep(10 * time.Millisecond) // give it time to start and write headers
+			// Run handler in goroutine and wait for context timeout
+			done := make(chan struct{})
+			go func() {
+				defer close(done)
+				proxy.ServeHTTP(rec, req)
+			}()
+
+			// Wait for either the handler to complete or context to timeout
+			<-ctx.Done()
+
+			// At this point, the handler has either finished or been cancelled
+			// Wait for the goroutine to fully exit before reading
+			<-done
+
+			// Now it's safe to read from rec - no more concurrent writes
 			assert.Equal(t, http.StatusOK, rec.Code)
 			assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
 		})
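The test fix addresses a race between the test body and the handler goroutine: the old version slept 10ms and then read from the httptest.ResponseRecorder while ServeHTTP might still be writing to it. The new shape, cancel via context timeout and then wait on a done channel until the goroutine has actually returned, is a reusable pattern. A self-contained sketch (hypothetical handler, not repo code):

    package main

    import (
        "context"
        "fmt"
        "net/http"
        "net/http/httptest"
        "time"
    )

    func main() {
        // A streaming handler that writes headers and then blocks until
        // the client's context is cancelled, like the proxy endpoints.
        h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            w.Header().Set("X-Accel-Buffering", "no")
            w.WriteHeader(http.StatusOK)
            <-r.Context().Done()
        })

        ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
        defer cancel()
        req := httptest.NewRequest("GET", "/stream", nil).WithContext(ctx)
        rec := httptest.NewRecorder()

        done := make(chan struct{})
        go func() {
            defer close(done)
            h.ServeHTTP(rec, req)
        }()

        <-ctx.Done() // the handler has been told to stop...
        <-done       // ...and has fully returned: rec is quiescent now

        fmt.Println(rec.Code, rec.Header().Get("X-Accel-Buffering")) // 200 no
    }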