delay TTL check until after all requests are complete (#25)

- fixes #25 where requests that last longer than the TTL will cause the
  process to be unloaded before the next request.
- new behavior, TTL waits until all requests are complete before
  checking timeout
This commit is contained in:
Benson Wong
2024-12-09 19:08:03 -08:00
parent 97dae50dc4
commit 5fbd53c616
2 changed files with 26 additions and 12 deletions

View File

@@ -122,16 +122,15 @@ func (p *Process) start() error {
// start a goroutine to check every second if // start a goroutine to check every second if
// the process should be stopped // the process should be stopped
go func() { go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for { for range time.Tick(time.Second) {
<-ticker.C // wait for all inflight requests to complete and ticker
p.inFlightRequests.Wait()
if time.Since(p.lastRequestHandled) > maxDuration { if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter) fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter)
p.Stop() p.Stop()
return
} }
} }
}() }()
@@ -275,7 +274,11 @@ func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
p.inFlightRequests.Add(1) p.inFlightRequests.Add(1)
defer p.inFlightRequests.Done()
defer func() {
p.lastRequestHandled = time.Now()
p.inFlightRequests.Done()
}()
if p.CurrentState() != StateReady { if p.CurrentState() != StateReady {
if err := p.start(); err != nil { if err := p.start(); err != nil {
@@ -285,8 +288,6 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
} }
} }
p.lastRequestHandled = time.Now()
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)

View File

@@ -82,18 +82,31 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard)) process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
defer process.Stop() defer process.Stop()
req := httptest.NewRequest("GET", "/test", nil) // this should take 4 seconds
req1 := httptest.NewRequest("GET", "/slow-respond?echo=1234&delay=1000ms", nil)
req2 := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Proxy the request (auto start) // Proxy the request (auto start) with a slow response that takes longer than config.UnloadAfter
process.ProxyRequest(w, req) process.ProxyRequest(w, req1)
t.Log("sending slow first request (4 seconds)")
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "1234")
assert.Equal(t, StateReady, process.CurrentState())
// ensure the TTL timeout does not race slow requests (see issue #25)
t.Log("sending second request (1 second)")
time.Sleep(time.Second)
w = httptest.NewRecorder()
process.ProxyRequest(w, req2)
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expectedMessage) assert.Contains(t, w.Body.String(), expectedMessage)
assert.Equal(t, StateReady, process.CurrentState()) assert.Equal(t, StateReady, process.CurrentState())
// wait 5 seconds // wait 5 seconds
t.Log("sleep 5 seconds and check if unloaded")
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
assert.Equal(t, StateStopped, process.CurrentState()) assert.Equal(t, StateStopped, process.CurrentState())
} }