Fix race conditions in proxy.Process (#349)
- Fix data races found in proxy.Process by Go's race detector.
- Add data race detection to the CI tests.

Fixes #348
Makefile | 2 +-
@@ -33,7 +33,7 @@ test: proxy/ui_dist/placeholder.txt
 
 # for CI - full test (takes longer)
 test-all: proxy/ui_dist/placeholder.txt
-    go test -count=1 ./proxy/...
+    go test -race -count=1 ./proxy/...
 
 ui/node_modules:
     cd ui && npm install
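For context, the -race flag rebuilds the test binary with Go's race detector, which reports unsynchronized concurrent memory accesses at runtime; it slows tests considerably, which is why it lives in the longer test-all CI target. A minimal, hypothetical test (not from this repository) showing the kind of failure it surfaces:

    package example

    import (
        "testing"
        "time"
    )

    // Passes under plain `go test`, but `go test -race` reports a data race:
    // one goroutine writes `last` while the test goroutine reads it with no
    // synchronization between them.
    func TestRacyTimestamp(t *testing.T) {
        var last time.Time
        done := make(chan struct{})
        go func() {
            last = time.Now() // concurrent write
            close(done)
        }()
        _ = last.IsZero() // concurrent read, flagged by -race
        <-done
    }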
@@ -12,6 +12,7 @@ import (
     "strconv"
     "strings"
     "sync"
+    "sync/atomic"
     "syscall"
     "time"
 
@@ -44,6 +45,7 @@ type Process struct {
     cmd *exec.Cmd
 
     // PR #155 called to cancel the upstream process
+    cmdMutex sync.RWMutex
     cancelUpstream context.CancelFunc
 
     // closed when command exits
@@ -55,12 +57,14 @@ type Process struct {
     healthCheckTimeout int
     healthCheckLoopInterval time.Duration
 
-    lastRequestHandled time.Time
+    lastRequestHandledMutex sync.RWMutex
+    lastRequestHandled time.Time
 
     stateMutex sync.RWMutex
     state ProcessState
 
     inFlightRequests sync.WaitGroup
+    inFlightRequestsCount atomic.Int32
 
     // used to block on multiple start() calls
     waitStarting sync.WaitGroup
@@ -107,6 +111,20 @@ func (p *Process) LogMonitor() *LogMonitor {
     return p.processLogger
 }
 
+// setLastRequestHandled sets the last request handled time in a thread-safe manner.
+func (p *Process) setLastRequestHandled(t time.Time) {
+    p.lastRequestHandledMutex.Lock()
+    defer p.lastRequestHandledMutex.Unlock()
+    p.lastRequestHandled = t
+}
+
+// getLastRequestHandled gets the last request handled time in a thread-safe manner.
+func (p *Process) getLastRequestHandled() time.Time {
+    p.lastRequestHandledMutex.RLock()
+    defer p.lastRequestHandledMutex.RUnlock()
+    return p.lastRequestHandled
+}
+
 // custom error types for swapping state
 var (
     ErrExpectedStateMismatch = errors.New("expected state mismatch")
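These accessors close a read/write race on lastRequestHandled: time.Time is a multi-word struct, so an unsynchronized read concurrent with a write can observe a torn value, and the race detector flags it either way. An alternative sketch (not what this commit does) that avoids the mutex by keeping the timestamp behind an atomic pointer, assuming Go 1.19+:

    package example

    import (
        "sync/atomic"
        "time"
    )

    type lastSeen struct {
        t atomic.Pointer[time.Time]
    }

    // set stores a fresh copy; Store/Load give race-free access without a lock.
    func (l *lastSeen) set(now time.Time) { l.t.Store(&now) }

    func (l *lastSeen) get() time.Time {
        if p := l.t.Load(); p != nil {
            return *p
        }
        return time.Time{} // zero value before the first set
    }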
@@ -130,6 +148,13 @@ func (p *Process) swapState(expectedState, newState ProcessState) (ProcessState,
     }
 
     p.state = newState
+
+    // Atomically increment waitStarting when entering StateStarting
+    // This ensures any thread that sees StateStarting will also see the WaitGroup counter incremented
+    if newState == StateStarting {
+        p.waitStarting.Add(1)
+    }
+
     p.proxyLogger.Debugf("<%s> swapState() State transitioned from %s to %s", p.ID, expectedState, newState)
     event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState})
     return p.state, nil
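Moving waitStarting.Add(1) under the same mutex that publishes the state change closes a check-then-act window: previously another goroutine could observe StateStarting and call waitStarting.Wait() before start() had executed its own Add(1), so Wait() could return while the start was still in progress (and an Add() racing with a Wait() on a zero counter is itself invalid WaitGroup usage). A simplified sketch of the pattern, with hypothetical names:

    package example

    import "sync"

    type state int

    const (
        stateStopped state = iota
        stateStarting
    )

    type proc struct {
        mu           sync.RWMutex
        state        state
        waitStarting sync.WaitGroup
    }

    // beginStart publishes the state and the WaitGroup increment under one
    // lock, so any reader that sees stateStarting also sees the counter at 1.
    func (p *proc) beginStart() {
        p.mu.Lock()
        defer p.mu.Unlock()
        p.state = stateStarting
        p.waitStarting.Add(1) // balanced by Done() when the start completes
    }

    // awaitReady blocks until an in-progress start finishes.
    func (p *proc) awaitReady() {
        p.mu.RLock()
        starting := p.state == stateStarting
        p.mu.RUnlock()
        if starting {
            p.waitStarting.Wait()
        }
    }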
@@ -158,6 +183,15 @@ func (p *Process) CurrentState() ProcessState {
     return p.state
 }
 
+// forceState forces the process state to the new state with mutex protection.
+// This should only be used in exceptional cases where the normal state transition
+// validation via swapState() cannot be used.
+func (p *Process) forceState(newState ProcessState) {
+    p.stateMutex.Lock()
+    defer p.stateMutex.Unlock()
+    p.state = newState
+}
+
 // start starts the upstream command, checks the health endpoint, and sets the state to Ready
 // it is a private method because starting is automatic but stopping can be called
 // at any time.
@@ -191,7 +225,7 @@ func (p *Process) start() error {
         }
     }
 
-    p.waitStarting.Add(1)
+    // waitStarting.Add(1) is now called atomically in swapState() when transitioning to StateStarting
     defer p.waitStarting.Done()
     cmdContext, ctxCancelUpstream := context.WithCancel(context.Background())
 
@@ -201,8 +235,11 @@ func (p *Process) start() error {
     p.cmd.Env = append(p.cmd.Environ(), p.config.Env...)
     p.cmd.Cancel = p.cmdStopUpstreamProcess
     p.cmd.WaitDelay = p.gracefulStopTimeout
+
+    p.cmdMutex.Lock()
     p.cancelUpstream = ctxCancelUpstream
     p.cmdWaitChan = make(chan struct{})
+    p.cmdMutex.Unlock()
 
     p.failedStartCount++ // this will be reset to zero when the process has successfully started
 
@@ -212,7 +249,7 @@ func (p *Process) start() error {
     // Set process state to failed
     if err != nil {
         if curState, swapErr := p.swapState(StateStarting, StateStopped); swapErr != nil {
-            p.state = StateStopped // force it into a stopped state
+            p.forceState(StateStopped) // force it into a stopped state
             return fmt.Errorf(
                 "failed to start command '%s' and state swap failed. command error: %v, current state: %v, state swap error: %v",
                 strings.Join(args, " "), err, curState, swapErr,
@@ -285,10 +322,12 @@ func (p *Process) start() error {
                 return
             }
 
-            // wait for all inflight requests to complete and ticker
-            p.inFlightRequests.Wait()
+            // skip the TTL check if there are inflight requests
+            if p.inFlightRequestsCount.Load() != 0 {
+                continue
+            }
 
-            if time.Since(p.lastRequestHandled) > maxDuration {
+            if time.Since(p.getLastRequestHandled()) > maxDuration {
                 p.proxyLogger.Infof("<%s> Unloading model, TTL of %ds reached", p.ID, p.config.UnloadAfter)
                 p.Stop()
                 return
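Replacing inFlightRequests.Wait() with an atomic counter check also fixes a WaitGroup misuse: per the sync documentation, an Add() with a positive delta that starts when the counter is zero must happen before any concurrent Wait(), which cannot be guaranteed when requests arrive while the TTL ticker is waiting. Polling a counter and continuing keeps the loop non-blocking. A self-contained sketch of the loop shape, with hypothetical names:

    package example

    import (
        "sync/atomic"
        "time"
    )

    // ttlLoop unloads an idle upstream once no requests are in flight and the
    // last request is older than ttl. inFlight mirrors inFlightRequestsCount.
    func ttlLoop(inFlight *atomic.Int32, lastHandled func() time.Time, ttl time.Duration, stop func()) {
        ticker := time.NewTicker(time.Second)
        defer ticker.Stop()
        for range ticker.C {
            if inFlight.Load() != 0 {
                continue // busy: re-check on the next tick
            }
            if time.Since(lastHandled()) > ttl {
                stop()
                return
            }
        }
    }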
@@ -344,7 +383,7 @@ func (p *Process) Shutdown() {
 
     p.stopCommand()
     // just force it to this state since there is no recovery from shutdown
-    p.state = StateShutdown
+    p.forceState(StateShutdown)
 }
 
 // stopCommand will send a SIGTERM to the process and wait for it to exit.
@@ -355,13 +394,18 @@ func (p *Process) stopCommand() {
         p.proxyLogger.Debugf("<%s> stopCommand took %v", p.ID, time.Since(stopStartTime))
     }()
 
-    if p.cancelUpstream == nil {
+    p.cmdMutex.RLock()
+    cancelUpstream := p.cancelUpstream
+    cmdWaitChan := p.cmdWaitChan
+    p.cmdMutex.RUnlock()
+
+    if cancelUpstream == nil {
         p.proxyLogger.Errorf("<%s> stopCommand has a nil p.cancelUpstream()", p.ID)
         return
     }
 
-    p.cancelUpstream()
-    <-p.cmdWaitChan
+    cancelUpstream()
+    <-cmdWaitChan
 }
 
 func (p *Process) checkHealthEndpoint(healthURL string) error {
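stopCommand now snapshots cancelUpstream and cmdWaitChan under a read lock and then works with the locals. Besides removing the unguarded reads, releasing the lock before blocking matters: waitForCmd() (below) takes the write lock around close(p.cmdWaitChan), so receiving from the channel while still holding cmdMutex could deadlock. A condensed sketch of the snapshot pattern, with hypothetical names:

    package example

    import (
        "context"
        "sync"
    )

    type proc struct {
        mu       sync.RWMutex
        cancel   context.CancelFunc
        waitChan chan struct{}
    }

    func (p *proc) stop() {
        // Snapshot shared fields under the lock, then release it before blocking.
        p.mu.RLock()
        cancel, waitChan := p.cancel, p.waitChan
        p.mu.RUnlock()

        if cancel == nil {
            return // process was never started
        }
        cancel()   // signal the upstream process to exit
        <-waitChan // closed by the waiter goroutine once the process exits
    }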
@@ -418,8 +462,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
     }
 
     p.inFlightRequests.Add(1)
+    p.inFlightRequestsCount.Add(1)
     defer func() {
-        p.lastRequestHandled = time.Now()
+        p.setLastRequestHandled(time.Now())
+        p.inFlightRequestsCount.Add(-1)
         p.inFlightRequests.Done()
     }()
 
@@ -519,13 +565,16 @@ func (p *Process) waitForCmd() {
     case StateStopping:
         if curState, err := p.swapState(StateStopping, StateStopped); err != nil {
            p.proxyLogger.Errorf("<%s> Process exited but could not swap to StateStopped. curState=%s, err: %v", p.ID, curState, err)
-            p.forceState(StateStopped)
+            p.forceState(StateStopped)
        }
     default:
         p.proxyLogger.Infof("<%s> process exited but not StateStopping, current state: %s", p.ID, currentState)
-        p.state = StateStopped // force it to be in this state
+        p.forceState(StateStopped) // force it to be in this state
     }
 
+    p.cmdMutex.Lock()
     close(p.cmdWaitChan)
+    p.cmdMutex.Unlock()
 }
 
 // cmdStopUpstreamProcess attemps to stop the upstream process gracefully
@@ -1075,18 +1075,28 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
 
     for _, endpoint := range endpoints {
         t.Run(endpoint, func(t *testing.T) {
-            ctx, cancel := context.WithCancel(context.Background())
+            ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
             defer cancel()
 
             req := httptest.NewRequest("GET", endpoint, nil)
             req = req.WithContext(ctx)
             rec := httptest.NewRecorder()
 
-            // We don't need the handler to fully complete, just to set the headers
-            // so run it in a goroutine and check the headers after a short delay
-            go proxy.ServeHTTP(rec, req)
-            time.Sleep(10 * time.Millisecond) // give it time to start and write headers
+            // Run handler in goroutine and wait for context timeout
+            done := make(chan struct{})
+            go func() {
+                defer close(done)
+                proxy.ServeHTTP(rec, req)
+            }()
+
+            // Wait for either the handler to complete or context to timeout
+            <-ctx.Done()
+
+            // At this point, the handler has either finished or been cancelled
+            // Wait for the goroutine to fully exit before reading
+            <-done
+
+            // Now it's safe to read from rec - no more concurrent writes
             assert.Equal(t, http.StatusOK, rec.Code)
             assert.Equal(t, "no", rec.Header().Get("X-Accel-Buffering"))
         })
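The old test read rec.Code and rec.Header() while the ServeHTTP goroutine could still be writing to the recorder; httptest.ResponseRecorder is not safe for concurrent use, and a 10ms sleep provides no ordering guarantee. Closing done inside the goroutine and receiving from it before reading establishes a happens-before edge. A minimal, standalone illustration of that guarantee:

    package example

    import "testing"

    func TestCloseOrdersWrites(t *testing.T) {
        var result int
        done := make(chan struct{})
        go func() {
            result = 42 // happens before close(done)...
            close(done)
        }()
        <-done // ...which happens before this receive returns
        if result != 42 {
            t.Fatalf("got %d, want 42", result)
        }
    }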