From caf9e98b1eff8760096435f212b39f5c4a278dbe Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Mon, 13 Oct 2025 16:42:49 -0700 Subject: [PATCH] Fix race conditions in proxy.Process (#349) - Fix data races found in proxy.Process by Go's race detector. - Add data race detection to the CI tests. Fixes #348 --- Makefile | 2 +- proxy/process.go | 77 +++++++++++++++++++++++++++++++------- proxy/proxymanager_test.go | 20 +++++++--- 3 files changed, 79 insertions(+), 20 deletions(-) diff --git a/Makefile b/Makefile index 2fc8e09..0c00b0a 100644 --- a/Makefile +++ b/Makefile @@ -33,7 +33,7 @@ test: proxy/ui_dist/placeholder.txt # for CI - full test (takes longer) test-all: proxy/ui_dist/placeholder.txt - go test -count=1 ./proxy/... + go test -race -count=1 ./proxy/... ui/node_modules: cd ui && npm install diff --git a/proxy/process.go b/proxy/process.go index 51a5bc6..971b000 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -12,6 +12,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -44,6 +45,7 @@ type Process struct { cmd *exec.Cmd // PR #155 called to cancel the upstream process + cmdMutex sync.RWMutex cancelUpstream context.CancelFunc // closed when command exits @@ -55,12 +57,14 @@ type Process struct { healthCheckTimeout int healthCheckLoopInterval time.Duration - lastRequestHandled time.Time + lastRequestHandledMutex sync.RWMutex + lastRequestHandled time.Time stateMutex sync.RWMutex state ProcessState - inFlightRequests sync.WaitGroup + inFlightRequests sync.WaitGroup + inFlightRequestsCount atomic.Int32 // used to block on multiple start() calls waitStarting sync.WaitGroup @@ -107,6 +111,20 @@ func (p *Process) LogMonitor() *LogMonitor { return p.processLogger } +// setLastRequestHandled sets the last request handled time in a thread-safe manner. 
+func (p *Process) setLastRequestHandled(t time.Time) { + p.lastRequestHandledMutex.Lock() + defer p.lastRequestHandledMutex.Unlock() + p.lastRequestHandled = t +} + +// getLastRequestHandled gets the last request handled time in a thread-safe manner. +func (p *Process) getLastRequestHandled() time.Time { + p.lastRequestHandledMutex.RLock() + defer p.lastRequestHandledMutex.RUnlock() + return p.lastRequestHandled +} + // custom error types for swapping state var ( ErrExpectedStateMismatch = errors.New("expected state mismatch") @@ -130,6 +148,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState, } p.state = newState + + // Atomically increment waitStarting when entering StateStarting + // This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented + if newState == StateStarting { + p.waitStarting.Add(1) + } + p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState) event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState}) return p.state, nil @@ -158,6 +183,15 @@ func (p *Process) CurrentState() ProcessState { return p.state } +// forceState forces the process state to the new state with mutex protection. +// This should only be used in exceptional cases where the normal state transition +// validation via swapState() cannot be used. +func (p *Process) forceState(newState ProcessState) { + p.stateMutex.Lock() + defer p.stateMutex.Unlock() + p.state = newState +} + // start starts the upstream command, checks the health endpoint, and sets the state to Ready // it is a private method because starting is automatic but stopping can be called // at any time. 
@@ -191,7 +225,7 @@ func (p *Process) start() error { } } - p.waitStarting.Add(1) + // waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting defer p.waitStarting.Done() cmdContext, ctxCancelUpstream := context.WithCancel(context.Background()) @@ -201,8 +235,11 @@ func (p *Process) start() error { p.cmd.Env = append(p.cmd.Environ(), p.config.Env...) p.cmd.Cancel = p.cmdStopUpstreamProcess p.cmd.WaitDelay = p.gracefulStopTimeout + + p.cmdMutex.Lock() p.cancelUpstream = ctxCancelUpstream p.cmdWaitChan = make(chan struct{}) + p.cmdMutex.Unlock() p.failedStartCount++ // this will be reset to zero when the process has successfully started @@ -212,7 +249,7 @@ func (p *Process) start() error { // Set process state to failed if err != nil { if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil { - p.state = StateStopped // force it into a stopped state + p.forceState(StateStopped) // force it into a stopped state return fmt.Errorf( "failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v", strings.Join(args, " "), err, curState, swapErr, @@ -285,10 +322,12 @@ func (p *Process) start() error { return } - // wait for all inflight requests to complete and ticker - p.inFlightRequests.Wait() + // skip the TTL check if there are inflight requests + if p.inFlightRequestsCount.Load() != 0 { + continue + } - if time.Since(p.lastRequestHandled) > maxDuration { + if time.Since(p.getLastRequestHandled()) > maxDuration { p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter) p.Stop() return @@ -344,7 +383,7 @@ func (p *Process) Shutdown() { p.stopCommand() // just force it to this state since there is no recovery from shutdown - p.state = StateShutdown + p.forceState(StateShutdown) } // stopCommand will send a SIGTERM to the process and wait for it to exit. 
@@ -355,13 +394,18 @@ func (p *Process) stopCommand() { p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime)) }() - if p.cancelUpstream == nil { + p.cmdMutex.RLock() + cancelUpstream := p.cancelUpstream + cmdWaitChan := p.cmdWaitChan + p.cmdMutex.RUnlock() + + if cancelUpstream == nil { p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID) return } - p.cancelUpstream() - <-p.cmdWaitChan + cancelUpstream() + <-cmdWaitChan } func (p *Process) checkHealthEndpoint(healthURL string) error { @@ -418,8 +462,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { } p.inFlightRequests.Add(1) + p.inFlightRequestsCount.Add(1) defer func() { - p.lastRequestHandled = time.Now() + p.setLastRequestHandled(time.Now()) + p.inFlightRequestsCount.Add(-1) p.inFlightRequests.Done() }() @@ -519,13 +565,16 @@ func (p *Process) waitForCmd() { case StateStopping: if curState, err := p.swapState(StateStopping, StateStopped); err != nil { p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. 
curState=%s, err: %v", p.ID, curState, err) - p.state = StateStopped + p.forceState(StateStopped) } default: p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState) - p.state = StateStopped // force it to be in this state + p.forceState(StateStopped) // force it to be in this state } + + p.cmdMutex.Lock() close(p.cmdWaitChan) + p.cmdMutex.Unlock() } // cmdStopUpstreamProcess attemps to stop the upstream process gracefully diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 3c8a158..b6462cf 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -1075,18 +1075,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) { for _, endpoint := range endpoints { t.Run(endpoint, func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() req := httptest.NewRequest("GET", endpoint, nil) req = req.WithContext(ctx) rec := httptest.NewRecorder() - // We don't need the handler to fully complete, just to set the headers - // so run it in a goroutine and check the headers after a short delay - go proxy.ServeHTTP(rec, req) - time.Sleep(10 * time.Millisecond) // give it time to start and write headers + // Run handler in goroutine and wait for context timeout + done := make(chan struct{}) + go func() { + defer close(done) + proxy.ServeHTTP(rec, req) + }() + // Block until the context times out; this does not return early if the handler completes first + <-ctx.Done() + + // At this point, the handler has either finished or been cancelled + // Wait for the goroutine to fully exit before reading + <-done + + // Now it's safe to read from rec - no more concurrent writes assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering")) })