Improve handling of processes that do not handle SIGTERM (#38)

- Process TTL goroutine did not have a return after .Stop()
- Improve logging
- Add test TestProcess_LowTTLValue to measure SIGTERM error rate
This commit is contained in:
Benson Wong
2025-01-20 14:39:52 -08:00
parent abdc2bfdb3
commit 2833517eef
2 changed files with 32 additions and 4 deletions

View File

@@ -135,6 +135,7 @@ func (p *Process) start() error {
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter) fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %ds reached.\n", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return
} }
} }
}() }()
@@ -165,7 +166,6 @@ func (p *Process) Stop() {
// Pretty sure this stopping code needs some work for windows and // Pretty sure this stopping code needs some work for windows and
// will be a source of pain in the future. // will be a source of pain in the future.
p.cmd.Process.Signal(syscall.SIGTERM)
sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) sigtermTimeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
@@ -174,9 +174,11 @@ func (p *Process) Stop() {
sigtermNormal <- p.cmd.Wait() sigtermNormal <- p.cmd.Wait()
}() }()
p.cmd.Process.Signal(syscall.SIGTERM)
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
fmt.Fprintf(p.logMonitor, "!!! process for %s timed out waiting to stop\n", p.ID) fmt.Fprintf(p.logMonitor, "XXX Process for %s timed out waiting to stop, sending SIGKILL to PID: %d\n", p.ID, p.cmd.Process.Pid)
p.cmd.Process.Kill() p.cmd.Process.Kill()
p.cmd.Wait() p.cmd.Wait()
case err := <-sigtermNormal: case err := <-sigtermNormal:

View File

@@ -67,7 +67,6 @@ func TestProcess_BrokenModelConfig(t *testing.T) {
assert.Contains(t, w.Body.String(), "unable to start process") assert.Contains(t, w.Body.String(), "unable to start process")
} }
// test that the process unloads after the TTL
func TestProcess_UnloadAfterTTL(t *testing.T) { func TestProcess_UnloadAfterTTL(t *testing.T) {
if testing.Short() { if testing.Short() {
t.Skip("skipping long auto unload TTL test") t.Skip("skipping long auto unload TTL test")
@@ -79,7 +78,7 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
config.UnloadAfter = 3 // seconds config.UnloadAfter = 3 // seconds
assert.Equal(t, 3, config.UnloadAfter) assert.Equal(t, 3, config.UnloadAfter)
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard)) process := NewProcess("ttl_test", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop() defer process.Stop()
// this should take 4 seconds // this should take 4 seconds
@@ -111,6 +110,33 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
assert.Equal(t, StateStopped, process.CurrentState()) assert.Equal(t, StateStopped, process.CurrentState())
} }
// TestProcess_LowTTLValue is a manually enabled stress test. It sets
// UnloadAfter to 1 second and then proxies 100 requests spaced 1.5 seconds
// apart, asserting that every one returns 200 and echoes its payload.
// NOTE(review): the spacing is longer than the TTL, so presumably the process
// is unloaded between requests and has to be started again each time — confirm
// against Process.ProxyRequest.
func TestProcess_LowTTLValue(t *testing.T) {
// Skipped unconditionally: the loop below sleeps for at least 150 seconds in
// total, so the test only runs when this condition is edited by hand.
if true { // change this code to run this ...
t.Skip("skipping test, edit process_test.go to run it ")
}
config := getTestSimpleResponderConfig("fast_ttl")
// Check the value before and after the override so the test fails loudly if
// the helper ever stops returning 0 for UnloadAfter.
assert.Equal(t, 0, config.UnloadAfter)
config.UnloadAfter = 1 // second
assert.Equal(t, 1, config.UnloadAfter)
// The log monitor is given os.Stdout here (other tests in this file pass
// io.Discard), presumably so the process log lines are visible during the run.
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop()
for i := 0; i < 100; i++ {
t.Logf("Waiting before sending request %d", i)
// Wait longer than the 1 second UnloadAfter before each request.
time.Sleep(1500 * time.Millisecond)
// Each request carries a unique echo value so a response can be matched to
// the iteration that produced it.
expected := fmt.Sprintf("echo=test_%d", i)
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=50ms", expected), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expected)
}
}
// issue #19 // issue #19
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) { func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() { if testing.Short() {