diff --git a/misc/simple-responder/simple-responder.go b/misc/simple-responder/simple-responder.go
index 6c9f235..1108edd 100644
--- a/misc/simple-responder/simple-responder.go
+++ b/misc/simple-responder/simple-responder.go
@@ -3,60 +3,137 @@ package main
 import (
 	"flag"
 	"fmt"
+	"io"
+	"log"
 	"net/http"
 	"os"
+	"os/signal"
+	"syscall"
+	"time"
+
+	"github.com/gin-gonic/gin"
 )
 
 func main() {
+	gin.SetMode(gin.TestMode)
 	// Define a command-line flag for the port
 	port := flag.String("port", "8080", "port to listen on")
 
 	// Define a command-line flag for the response message
 	responseMessage := flag.String("respond", "hi", "message to respond with")
 
+	silent := flag.Bool("silent", false, "disable all logging")
+
 	flag.Parse() // Parse the command-line flags
 
-	responseMessageHandler := func(w http.ResponseWriter, r *http.Request) {
-		// Set the header to text/plain
-		w.Header().Set("Content-Type", "text/plain")
-		fmt.Fprintln(w, *responseMessage)
-	}
+	// Create a new Gin router
+	r := gin.New()
 
 	// Set up the handler function using the provided response message
-	http.HandleFunc("/v1/chat/completions", responseMessageHandler)
-	http.HandleFunc("/v1/completions", responseMessageHandler)
-	http.HandleFunc("/test", responseMessageHandler)
+	r.POST("/v1/chat/completions", func(c *gin.Context) {
+		c.Header("Content-Type", "text/plain")
 
-	http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set("Content-Type", "text/plain")
-		fmt.Fprintln(w, *responseMessage)
+		// add a wait to simulate a slow query
+		if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
+			time.Sleep(wait)
+		}
+
+		c.String(200, *responseMessage)
+	})
+
+	r.POST("/v1/completions", func(c *gin.Context) {
+		c.Header("Content-Type", "text/plain")
+		c.String(200, *responseMessage)
+	})
+
+	r.GET("/slow-respond", func(c *gin.Context) {
+		echo := c.Query("echo")
+		delay := c.Query("delay")
+
+		if echo == "" {
+			echo = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+		}
+
+		// Parse the duration
+		if delay == "" {
+			delay = "100ms"
+		}
+
+		t, err := time.ParseDuration(delay)
+		if err != nil {
+			c.Header("Content-Type", "text/plain")
+			c.String(http.StatusBadRequest, fmt.Sprintf("Invalid duration: %s", err))
+			return
+		}
+
+		c.Header("Content-Type", "text/plain")
+		for _, char := range echo {
+			c.Writer.Write([]byte(string(char)))
+			c.Writer.Flush()
+
+			// wait
+			<-time.After(t)
+		}
+	})
+
+	r.GET("/test", func(c *gin.Context) {
+		c.Header("Content-Type", "text/plain")
+		c.String(200, *responseMessage)
+	})
+
+	r.GET("/env", func(c *gin.Context) {
+		c.Header("Content-Type", "text/plain")
+		c.String(200, "%s\n", *responseMessage)
 
 		// Get environment variables
 		envVars := os.Environ()
 
 		// Write each environment variable to the response
 		for _, envVar := range envVars {
-			fmt.Fprintln(w, envVar)
+			c.String(200, "%s\n", envVar)
 		}
 	})
 
 	// Set up the /health endpoint handler function
-	http.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set("Content-Type", "application/json")
-		response := `{"status": "ok"}`
-		w.Write([]byte(response))
+	r.GET("/health", func(c *gin.Context) {
+		c.Header("Content-Type", "application/json")
+		c.JSON(200, gin.H{"status": "ok"})
 	})
 
-	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
-		w.Header().Set("Content-Type", "text/plain")
-		fmt.Fprintf(w, "%s %s", r.Method, r.URL.Path)
+	r.GET("/", func(c *gin.Context) {
+		c.Header("Content-Type", "text/plain")
+		c.String(200, fmt.Sprintf("%s %s", c.Request.Method, c.Request.URL.Path))
 	})
 
 	address := "127.0.0.1:" + *port // Address with the specified port
-	fmt.Printf("Server is listening on port %s\n", *port)
 
-	// Start the server and log any error if it occurs
-	if err := http.ListenAndServe(address, nil); err != nil {
-		fmt.Printf("Error starting server: %s\n", err)
+	srv := &http.Server{
+		Addr:    address,
+		Handler: r.Handler(),
 	}
+
+	// Disable logging if the --silent flag is set
+	if *silent {
+		gin.SetMode(gin.ReleaseMode)
+		gin.DefaultWriter = io.Discard
+		log.SetOutput(io.Discard)
+	}
+
+	go func() {
+		log.Printf("simple-responder listening on %s\n", address)
+		// service connections
+		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
+			log.Fatalf("simple-responder err: %s\n", err)
+		}
+	}()
+
+	// Block until an interrupt or termination signal arrives, then exit.
+	// NOTE: srv.Shutdown is never called, so in-flight requests are cut off.
+	quit := make(chan os.Signal, 1)
+	// kill (no param) sends syscall.SIGTERM by default
+	// kill -2 is syscall.SIGINT
+	// kill -9 is syscall.SIGKILL, which can't be caught, so it is not registered
+	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
+	<-quit
+	log.Println("simple-responder shutting down")
 }
diff --git a/proxy/helpers_test.go b/proxy/helpers_test.go
index e8efa5b..d6982c5 100644
--- a/proxy/helpers_test.go
+++ b/proxy/helpers_test.go
@@ -51,7 +51,7 @@ func getTestSimpleResponderConfigPort(expectedMessage string, port int) ModelCon
 
 	// Create a process configuration
 	return ModelConfig{
-		Cmd:           fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, port, expectedMessage),
+		Cmd:           fmt.Sprintf("%s --port %d --silent --respond %s", binaryPath, port, expectedMessage),
 		Proxy:         fmt.Sprintf("http://127.0.0.1:%d", port),
 		CheckEndpoint: "/health",
 	}
diff --git a/proxy/process.go b/proxy/process.go
index 10212f3..808917b 100644
--- a/proxy/process.go
+++ b/proxy/process.go
@@ -14,6 +14,15 @@ import (
 	"time"
 )
 
+// ProcessState is the lifecycle state of a Process.
+type ProcessState string
+
+const (
+	StateStopped ProcessState = ProcessState("stopped")
+	StateReady   ProcessState = ProcessState("ready")
+	StateFailed  ProcessState = ProcessState("failed")
+)
+
 type Process struct {
 	sync.Mutex
 
@@ -23,8 +32,12 @@ type Process struct {
 	logMonitor         *LogMonitor
 	healthCheckTimeout int
 
-	isRunning          bool
 	lastRequestHandled time.Time
+
+	stateMutex sync.RWMutex
+	state      ProcessState
+
+	inFlightRequests sync.WaitGroup
 }
 
 func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
@@ -34,16 +47,22 @@ func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonito
 		cmd:                nil,
 		logMonitor:         logMonitor,
 		healthCheckTimeout: healthCheckTimeout,
+		state:              StateStopped,
 	}
 }
 
-// start the process and check it for errors
+// start launches the process and returns once it is ready
 func (p *Process) start() error {
-	p.Lock()
-	defer p.Unlock()
 
-	if p.isRunning {
-		return fmt.Errorf("process already running")
+	p.stateMutex.Lock()
+	defer p.stateMutex.Unlock()
+
+	if p.state == StateReady {
+		return nil
+	}
+
+	if p.state == StateFailed {
+		return fmt.Errorf("process is in a failed state and can not be restarted")
 	}
 
 	args, err := p.config.SanitizedCommand()
@@ -57,34 +76,48 @@ func (p *Process) start() error {
 	p.cmd.Env = p.config.Env
 
 	err = p.cmd.Start()
-	p.isRunning = true
 
 	if err != nil {
 		return err
 	}
 
-	// watch for the command to exit
-	cmdCtx, cancel := context.WithCancelCause(context.Background())
+	// One of three things can happen at this stage:
+	// 1. The command exits unexpectedly
+	// 2. The health check fails
+	// 3. The health check passes
+	//
+	// only in the third case will the process be considered Ready to accept requests
+	healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
+	defer cancelHealthCheck(nil) // clean up
+	cmdWaitChan := make(chan error, 1)
+	healthCheckChan := make(chan error, 1)
 
-	// monitor the command's exit status. Usually this happens if
-	// the process exited unexpectedly
 	go func() {
-		err := p.cmd.Wait()
-		if err != nil {
-			cancel(fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error()))
-		} else {
-			cancel(nil)
-		}
-
-		p.isRunning = false
+		// possible cmd exits early
+		cmdWaitChan <- p.cmd.Wait()
 	}()
 
-	// wait a bit for process to start before checking the health endpoint
-	time.Sleep(250 * time.Millisecond)
+	go func() {
+		<-time.After(250 * time.Millisecond) // give process a bit of time to start
+		healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
+	}()
 
-	// wait for checkHealthEndpoint
-	if err := p.checkHealthEndpoint(cmdCtx); err != nil {
+	select {
+	case err := <-cmdWaitChan:
+		p.state = StateFailed
+		if err != nil {
+			err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
+		} else {
+			err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
+		}
+		cancelHealthCheck(err)
 		return err
+	case err := <-healthCheckChan:
+		if err != nil {
+			p.state = StateFailed
+			p.cmd.Process.Kill() // Stop() ignores non-Ready states, so reap the unhealthy child here
+			return err
+		}
 	}
 
 	if p.config.UnloadAfter > 0 {
@@ -106,27 +139,65 @@ func (p *Process) start() error {
 		}()
 	}
 
+	p.state = StateReady
 	return nil
 }
 
 func (p *Process) Stop() {
-	p.Lock()
-	defer p.Unlock()
+	// wait for any inflight requests before proceeding
+	p.inFlightRequests.Wait()
 
-	if !p.isRunning || p.cmd == nil || p.cmd.Process == nil {
+	p.stateMutex.Lock()
+	defer p.stateMutex.Unlock()
+
+	if p.state != StateReady {
 		return
 	}
 
+	if p.cmd == nil || p.cmd.Process == nil {
+		// this situation should never happen... but if it does just update the state
+		fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
+		p.state = StateStopped
+		return
+	}
+
+	// Pretty sure this stopping code needs some work for windows and
+	// will be a source of pain in the future.
+
 	p.cmd.Process.Signal(syscall.SIGTERM)
-	p.cmd.Process.Wait()
-	p.isRunning = false
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	done := make(chan error, 1)
+	go func() {
+		done <- p.cmd.Wait()
+	}()
+
+	select {
+	case <-ctx.Done():
+		fmt.Printf("!!! process for %s timed out waiting to stop\n", p.ID)
+		p.cmd.Process.Kill()
+		<-done // the goroutine above is already in Wait; collect its result instead of calling Wait again
+	case err := <-done:
+		if err != nil {
+			if err.Error() != "wait: no child processes" {
+				// possible that simple-responder for testing is just not
+				// exiting right, so suppress those errors.
+				fmt.Printf("!!! process for %s stopped with error > %v\n", p.ID, err)
+			}
+		}
+	}
+	p.state = StateStopped
 }
 
-func (p *Process) IsRunning() bool {
-	return p.isRunning
+// CurrentState returns the process state, read under the state mutex.
+func (p *Process) CurrentState() ProcessState {
+	p.stateMutex.RLock()
+	defer p.stateMutex.RUnlock()
+	return p.state
 }
 
-func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
+func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
 	if p.config.Proxy == "" {
 		return fmt.Errorf("no upstream available to check /health")
 	}
@@ -158,7 +229,7 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
 			return err
 		}
 
-		ctx, cancel := context.WithTimeout(cmdCtx, time.Second)
+		ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
 		defer cancel()
 		req = req.WithContext(ctx)
 		resp, err := client.Do(req)
@@ -205,7 +276,11 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
 }
 
 func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
-	if !p.isRunning {
+
+	p.inFlightRequests.Add(1)
+	defer p.inFlightRequests.Done()
+
+	if p.CurrentState() != StateReady {
 		if err := p.start(); err != nil {
 			errstr := fmt.Sprintf("unable to start process: %s", err)
 			http.Error(w, errstr, http.StatusInternalServerError)
diff --git a/proxy/process_test.go b/proxy/process_test.go
index 542f575..8d41771 100644
--- a/proxy/process_test.go
+++ b/proxy/process_test.go
@@ -1,9 +1,12 @@
 package proxy
 
 import (
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
+	"os"
+	"sync"
 	"testing"
 	"time"
 
@@ -23,9 +26,9 @@ func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
 	w := httptest.NewRecorder()
 
 	// process is automatically started
-	assert.False(t, process.IsRunning())
+	assert.Equal(t, StateStopped, process.CurrentState())
 	process.ProxyRequest(w, req)
-	assert.True(t, process.IsRunning())
+	assert.Equal(t, StateReady, process.CurrentState())
 
 	assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
 	assert.Contains(t, w.Body.String(), expectedMessage)
@@ -84,13 +87,67 @@ func TestProcess_UnloadAfterTTL(t *testing.T) {
 
 	// Proxy the request (auto start)
 	process.ProxyRequest(w, req)
+
 	assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
 	assert.Contains(t, w.Body.String(), expectedMessage)
 
-	assert.True(t, process.IsRunning())
+	assert.Equal(t, StateReady, process.CurrentState())
 
 	// wait 5 seconds
 	time.Sleep(5 * time.Second)
-
-	assert.False(t, process.IsRunning())
+	assert.Equal(t, StateStopped, process.CurrentState())
+}
+
+// issue #19
+func TestProcess_HTTPRequestsHaveTimeToFinish(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping long test")
+	}
+
+	expectedMessage := "12345"
+	config := getTestSimpleResponderConfig(expectedMessage)
+	process := NewProcess("t", 10, config, NewLogMonitorWriter(os.Stdout))
+	defer process.Stop()
+
+	results := map[string]string{
+		"12345": "",
+		"abcde": "",
+		"fghij": "",
+	}
+
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+
+	for key := range results {
+		wg.Add(1)
+		go func(key string) {
+			defer wg.Done()
+			// send a request that should take 5 * 200ms (1 second) to complete
+			req := httptest.NewRequest("GET", fmt.Sprintf("/slow-respond?echo=%s&delay=200ms", key), nil)
+			w := httptest.NewRecorder()
+
+			process.ProxyRequest(w, req)
+
+			if w.Code != http.StatusOK {
+				t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
+			}
+
+			mu.Lock()
+			results[key] = w.Body.String()
+			mu.Unlock()
+
+		}(key)
+	}
+
+	// stop the requests in the middle
+	go func() {
+		<-time.After(500 * time.Millisecond)
+		process.Stop()
+	}()
+
+	wg.Wait()
+
+	for key, result := range results {
+		assert.Equal(t, key, result)
+	}
 }
diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go
index 3c3b874..8c5f271 100644
--- a/proxy/proxymanager.go
+++ b/proxy/proxymanager.go
@@ -69,9 +69,14 @@ func (pm *ProxyManager) StopProcesses() {
 
 // for internal usage
 func (pm *ProxyManager) stopProcesses() {
+	if len(pm.currentProcesses) == 0 {
+		return
+	}
+
 	for _, process := range pm.currentProcesses {
 		process.Stop()
 	}
+
 	pm.currentProcesses = make(map[string]*Process)
 }
 
@@ -185,7 +190,6 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
 
 		process.ProxyRequest(c.Writer, c.Request)
 	}
-
 }
 
 func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go
index 2d15d7e..d17cdf7 100644
--- a/proxy/proxymanager_test.go
+++ b/proxy/proxymanager_test.go
@@ -5,7 +5,9 @@ import (
 	"fmt"
 	"net/http"
 	"net/http/httptest"
+	"sync"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/assert"
 )
@@ -72,5 +74,60 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
 
 	_, exists = proxy.currentProcesses["test/model2"]
 	assert.True(t, exists, "expected test/model2 key in currentProcesses")
-
+}
+
+// When a request for a different model comes in ProxyManager should wait until
+// the first request is complete before swapping. Both requests should complete
+func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
+	if testing.Short() {
+		t.Skip("skipping slow test")
+	}
+
+	config := &Config{
+		HealthCheckTimeout: 15,
+		Models: map[string]ModelConfig{
+			"model1": getTestSimpleResponderConfig("model1"),
+			"model2": getTestSimpleResponderConfig("model2"),
+			"model3": getTestSimpleResponderConfig("model3"),
+		},
+	}
+
+	proxy := New(config)
+	defer proxy.StopProcesses()
+
+	results := map[string]string{}
+
+	var wg sync.WaitGroup
+	var mu sync.Mutex
+
+	for key := range config.Models {
+		wg.Add(1)
+		go func(key string) {
+			defer wg.Done()
+
+			reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
+			req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
+			w := httptest.NewRecorder()
+
+			proxy.HandlerFunc(w, req)
+
+			if w.Code != http.StatusOK {
+				t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
+			}
+
+			mu.Lock()
+
+			results[key] = w.Body.String()
+			mu.Unlock()
+		}(key)
+
+		<-time.After(time.Millisecond)
+	}
+
+	wg.Wait()
+	assert.Len(t, results, len(config.Models))
+
+	for key, result := range results {
+		assert.Equal(t, key, result)
+	}
 }