Improve Concurrency and Parallel Request Handling (#19)

Rewrite the swap behaviour so that in-flight requests block process swapping until they are completed. 

Additionally: 

- add tests for parallel requests with proxy.ProxyManager and proxy.Process
- improve Process startup behaviour and simplified the code 
- stopping of processes are sent SIGTERM and have 5 seconds to terminate, before they are killed
This commit is contained in:
Benson Wong
2024-11-30 15:24:42 -08:00
committed by GitHub
parent e363f8f498
commit cf82b3c633
6 changed files with 331 additions and 64 deletions

View File

@@ -3,60 +3,137 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"io"
"log"
"net/http" "net/http"
"os" "os"
"os/signal"
"syscall"
"time"
"github.com/gin-gonic/gin"
) )
func main() { func main() {
gin.SetMode(gin.TestMode)
// Define a command-line flag for the port // Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on") port := flag.String("port", "8080", "port to listen on")
// Define a command-line flag for the response message // Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with") responseMessage := flag.String("respond", "hi", "message to respond with")
silent := flag.Bool("silent", false, "disable all logging")
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
responseMessageHandler := func(w http.ResponseWriter, r *http.Request) { // Create a new Gin router
// Set the header to text/plain r := gin.New()
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, *responseMessage)
}
// Set up the handler function using the provided response message // Set up the handler function using the provided response message
http.HandleFunc("/v1/chat/completions", responseMessageHandler) r.POST("/v1/chat/completions", func(c *gin.Context) {
http.HandleFunc("/v1/completions", responseMessageHandler) c.Header("Content-Type", "text/plain")
http.HandleFunc("/test", responseMessageHandler)
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) { // add a wait to simulate a slow query
w.Header().Set("Content-Type", "text/plain") if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
fmt.Fprintln(w, *responseMessage) time.Sleep(wait)
}
c.String(200, *responseMessage)
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
r.GET("/slow-respond", func(c *gin.Context) {
echo := c.Query("echo")
delay := c.Query("delay")
if echo == "" {
echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
}
// Parse the duration
if delay == "" {
delay = "100ms"
}
t, err := time.ParseDuration(delay)
if err != nil {
c.Header("Content-Type", "text/plain")
c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
return
}
c.Header("Content-Type", "text/plain")
for _, char := range echo {
c.Writer.Write([]byte(string(char)))
c.Writer.Flush()
// wait
<-time.After(t)
}
})
r.GET("/test", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
})
r.GET("/env", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)
// Get environment variables // Get environment variables
envVars := os.Environ() envVars := os.Environ()
// Write each environment variable to the response // Write each environment variable to the response
for _, envVar := range envVars { for _, envVar := range envVars {
fmt.Fprintln(w, envVar) c.String(200, envVar)
} }
}) })
// Set up the /health endpoint handler function // Set up the /health endpoint handler function
http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) { r.GET("/health", func(c *gin.Context) {
w.Header().Set("Content-Type", "application/json") c.Header("Content-Type", "application/json")
response := `{"status": "ok"}` c.JSON(200, gin.H{"status": "ok"})
w.Write([]byte(response))
}) })
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { r.GET("/", func(c *gin.Context) {
w.Header().Set("Content-Type", "text/plain") c.Header("Content-Type", "text/plain")
fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path) c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
}) })
address := "127.0.0.1:" + *port // Address with the specified port address := "127.0.0.1:" + *port // Address with the specified port
fmt.Printf("Server is listening on port %s\n", *port)
// Start the server and log any error if it occurs srv := &http.Server{
if err := http.ListenAndServe(address, nil); err != nil { Addr: address,
fmt.Printf("Error starting server: %s\n", err) Handler: r.Handler(),
} }
// Disable logging if the --silent flag is set
if *silent {
gin.SetMode(gin.ReleaseMode)
gin.DefaultWriter = io.Discard
log.SetOutput(io.Discard)
}
go func() {
log.Printf("simple-responder listening on %s\n", address)
// service connections
if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("simple-responder err: %s\n", err)
}
}()
// Wait for interrupt signal to gracefully shutdown the server with
// a timeout of 5 seconds.
quit := make(chan os.Signal, 1)
// kill (no param) default send syscall.SIGTERM
// kill -2 is syscall.SIGINT
// kill -9 is syscall.SIGKILL but can't be catch, so don't need add it
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("simple-responder shutting down")
} }

View File

@@ -51,7 +51,7 @@ func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelCon
// Create a process configuration // Create a process configuration
return ModelConfig{ return ModelConfig{
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, port, expectedMessage), Cmd: fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", port), Proxy: fmt.Sprintf("http://127.0.0.1:%d", port),
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }

View File

@@ -14,6 +14,14 @@ import (
"time" "time"
) )
type ProcessState string
const (
StateStopped ProcessState = ProcessState("stopped")
StateReady ProcessState = ProcessState("ready")
StateFailed ProcessState = ProcessState("failed")
)
type Process struct { type Process struct {
sync.Mutex sync.Mutex
@@ -23,8 +31,12 @@ type Process struct {
logMonitor *LogMonitor logMonitor *LogMonitor
healthCheckTimeout int healthCheckTimeout int
isRunning bool
lastRequestHandled time.Time lastRequestHandled time.Time
stateMutex sync.RWMutex
state ProcessState
inFlightRequests sync.WaitGroup
} }
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
@@ -34,16 +46,22 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
cmd: nil, cmd: nil,
logMonitor: logMonitor, logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout, healthCheckTimeout: healthCheckTimeout,
state: StateStopped,
} }
} }
// start the process and check it for errors // start the process and returns when it is ready
func (p *Process) start() error { func (p *Process) start() error {
p.Lock()
defer p.Unlock()
if p.isRunning { p.stateMutex.Lock()
return fmt.Errorf("process already running") defer p.stateMutex.Unlock()
if p.state == StateReady {
return nil
}
if p.state == StateFailed {
return fmt.Errorf("process is in a failed state and can not be restarted")
} }
args, err := p.config.SanitizedCommand() args, err := p.config.SanitizedCommand()
@@ -57,34 +75,47 @@ func (p *Process) start() error {
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
p.isRunning = true
if err != nil { if err != nil {
return err return err
} }
// watch for the command to exit // One of three things can happen at this stage:
cmdCtx, cancel := context.WithCancelCause(context.Background()) // 1. The command exits unexpectedly
// 2. The health check fails
// 3. The health check passes
//
// only in the third case will the process be considered Ready to accept
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
defer cancelHealthCheck(nil) // clean up
cmdWaitChan := make(chan error, 1)
healthCheckChan := make(chan error, 1)
// monitor the command's exit status. Usually this happens if
// the process exited unexpectedly
go func() { go func() {
err := p.cmd.Wait() // possible cmd exits early
if err != nil { cmdWaitChan <- p.cmd.Wait()
cancel(fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error()))
} else {
cancel(nil)
}
p.isRunning = false
}() }()
// wait a bit for process to start before checking the health endpoint go func() {
time.Sleep(250 * time.Millisecond) <-time.After(250 * time.Millisecond) // give process a bit of time to start
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
}()
// wait for checkHealthEndpoint select {
if err := p.checkHealthEndpoint(cmdCtx); err != nil { case err := <-cmdWaitChan:
p.state = StateFailed
if err != nil {
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
} else {
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
}
cancelHealthCheck(err)
return err return err
case err := <-healthCheckChan:
if err != nil {
p.state = StateFailed
return err
}
} }
if p.config.UnloadAfter > 0 { if p.config.UnloadAfter > 0 {
@@ -106,27 +137,64 @@ func (p *Process) start() error {
}() }()
} }
p.state = StateReady
return nil return nil
} }
func (p *Process) Stop() { func (p *Process) Stop() {
p.Lock() // wait for any inflight requests before proceeding
defer p.Unlock() p.inFlightRequests.Wait()
if !p.isRunning || p.cmd == nil || p.cmd.Process == nil { p.stateMutex.Lock()
defer p.stateMutex.Unlock()
if p.state != StateReady {
return return
} }
if p.cmd == nil || p.cmd.Process == nil {
// this situation should never happen... but if it does just update the state
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
p.state = StateStopped
return
}
// Pretty sure this stopping code needs some work for windows and
// will be a source of pain in the future.
p.cmd.Process.Signal(syscall.SIGTERM) p.cmd.Process.Signal(syscall.SIGTERM)
p.cmd.Process.Wait() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
p.isRunning = false defer cancel()
done := make(chan error, 1)
go func() {
done <- p.cmd.Wait()
}()
select {
case <-ctx.Done():
fmt.Printf("!!! process for %s timed out waiting to stop\n", p.ID)
p.cmd.Process.Kill()
p.cmd.Wait()
case err := <-done:
if err != nil {
if err.Error() != "wait: no child processes" {
// possible that simple-responder for testing is just not
// existing right, so suppress those errors.
fmt.Printf("!!! process for %s stopped with error > %v\n", p.ID, err)
}
}
}
p.state = StateStopped
} }
func (p *Process) IsRunning() bool { func (p *Process) CurrentState() ProcessState {
return p.isRunning p.stateMutex.RLock()
defer p.stateMutex.RUnlock()
return p.state
} }
func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error { func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
if p.config.Proxy == "" { if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health") return fmt.Errorf("no upstream available to check /health")
} }
@@ -158,7 +226,7 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
return err return err
} }
ctx, cancel := context.WithTimeout(cmdCtx, time.Second) ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
defer cancel() defer cancel()
req = req.WithContext(ctx) req = req.WithContext(ctx)
resp, err := client.Do(req) resp, err := client.Do(req)
@@ -205,7 +273,11 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if !p.isRunning {
p.inFlightRequests.Add(1)
defer p.inFlightRequests.Done()
if p.CurrentState() != StateReady {
if err := p.start(); err != nil { if err := p.start(); err != nil {
errstr := fmt.Sprintf("unable to start process: %s", err) errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusInternalServerError) http.Error(w, errstr, http.StatusInternalServerError)

View File

@@ -1,9 +1,12 @@
package proxy package proxy
import ( import (
"fmt"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"sync"
"testing" "testing"
"time" "time"
@@ -23,9 +26,9 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// process is automatically started // process is automatically started
assert.False(t, process.IsRunning()) assert.Equal(t, StateStopped, process.CurrentState())
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
assert.True(t, process.IsRunning()) assert.Equal(t, StateReady, process.CurrentState())
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expectedMessage) assert.Contains(t, w.Body.String(), expectedMessage)
@@ -84,13 +87,67 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
// Proxy the request (auto start) // Proxy the request (auto start)
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), expectedMessage) assert.Contains(t, w.Body.String(), expectedMessage)
assert.True(t, process.IsRunning()) assert.Equal(t, StateReady, process.CurrentState())
// wait 5 seconds // wait 5 seconds
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
assert.Equal(t, StateStopped, process.CurrentState())
assert.False(t, process.IsRunning()) }
// issue #19
func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
if testing.Short() {
t.Skip("skipping long test")
}
expectedMessage := "12345"
config := getTestSimpleResponderConfig(expectedMessage)
process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
defer process.Stop()
results := map[string]string{
"12345": "",
"abcde": "",
"fghij": "",
}
var wg sync.WaitGroup
var mu sync.Mutex
for key := range results {
wg.Add(1)
go func(key string) {
defer wg.Done()
// send a request that should take 5 * 200ms (1 second) to complete
req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
mu.Unlock()
}(key)
}
// stop the requests in the middle
go func() {
<-time.After(500 * time.Millisecond)
process.Stop()
}()
wg.Wait()
for key, result := range results {
assert.Equal(t, key, result)
}
} }

View File

@@ -69,9 +69,14 @@ func (pm *ProxyManager) StopProcesses() {
// for internal usage // for internal usage
func (pm *ProxyManager) stopProcesses() { func (pm *ProxyManager) stopProcesses() {
if len(pm.currentProcesses) == 0 {
return
}
for _, process := range pm.currentProcesses { for _, process := range pm.currentProcesses {
process.Stop() process.Stop()
} }
pm.currentProcesses = make(map[string]*Process) pm.currentProcesses = make(map[string]*Process)
} }
@@ -185,7 +190,6 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
process.ProxyRequest(c.Writer, c.Request) process.ProxyRequest(c.Writer, c.Request)
} }
} }
func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) { func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {

View File

@@ -5,7 +5,9 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"sync"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
@@ -72,5 +74,60 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
_, exists = proxy.currentProcesses["test/model2"] _, exists = proxy.currentProcesses["test/model2"]
assert.True(t, exists, "expected test/model2 key in currentProcesses") assert.True(t, exists, "expected test/model2 key in currentProcesses")
}
// When a request for a different model comes in ProxyManager should wait until
// the first request is complete before swapping. Both requests should complete
func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
if testing.Short() {
t.Skip("skipping slow test")
}
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
results := map[string]string{}
var wg sync.WaitGroup
var mu sync.Mutex
for key := range config.Models {
wg.Add(1)
go func(key string) {
defer wg.Done()
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
}
mu.Lock()
results[key] = w.Body.String()
mu.Unlock()
}(key)
<-time.After(time.Millisecond)
}
wg.Wait()
assert.Len(t, results, len(config.Models))
for key, result := range results {
assert.Equal(t, key, result)
}
} }