From bc652709a536b4b789d6058ce7537deb67d7affe Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 12 May 2025 10:37:00 +1000 Subject: [PATCH] Add config hot-reload (#106) introduce --watch-config command line option to reload ProxyManager when configuration changes. --- README.md | 13 +++- go.mod | 1 + go.sum | 26 ++----- llama-swap.go | 148 ++++++++++++++++++++++++++++++++++--- proxy/process.go | 22 +++--- proxy/proxymanager.go | 21 +++--- proxy/proxymanager_test.go | 37 +++++----- 7 files changed, 196 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 89f841a..b8c75fe 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,7 @@ groups: Docker is the quickest way to try out llama-swap: -``` +```shell # use CPU inference $ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu @@ -231,7 +231,7 @@ Specific versions are also available and are tagged with the llama-swap, archite Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration. -``` +```shell $ docker run -it --rm --runtime nvidia -p 9292:8080 \ -v /path/to/models:/models \ -v /path/to/custom/config.yaml:/app/config.yaml \ @@ -246,7 +246,12 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are 1. Create a configuration file, see [config.example.yaml](config.example.yaml) 1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture. -1. Run the binary with `llama-swap --config path/to/config.yaml` +1. Run the binary with `llama-swap --config path/to/config.yaml`. + Available flags: + - `--config`: Path to the configuration file (default: `config.yaml`). + - `--listen`: Address and port to listen on (default: `:8080`). + - `--version`: Show version information and exit. + - `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`). ### Building from source @@ -261,7 +266,7 @@ Open the `http:///logs` with your browser to get a web interface with stre Of course, CLI access is also supported: -``` +```shell # sends up to the last 10KB of logs curl http://host/logs' diff --git a/go.mod b/go.mod index e27169a..407979a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/mostlygeek/llama-swap go 1.23.0 require ( + github.com/fsnotify/fsnotify v1.9.0 github.com/gin-gonic/gin v1.10.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/stretchr/testify v1.9.0 diff --git a/go.sum b/go.sum index 9eb31fd..6ea5c34 100644 --- a/go.sum +++ b/go.sum @@ -9,12 +9,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= @@ -23,6 +27,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= @@ -74,34 +80,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= -golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI= -golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= -golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= -golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= -golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac= -golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= -golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= -golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c= -golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= -golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= -golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= -golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk= -golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= -golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/llama-swap.go b/llama-swap.go index c570220..ce9c8ea 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -1,25 +1,34 @@ package main import ( + "context" "flag" "fmt" + "log" + "net/http" "os" "os/signal" + "path/filepath" "syscall" + "time" + "github.com/fsnotify/fsnotify" "github.com/gin-gonic/gin" "github.com/mostlygeek/llama-swap/proxy" ) -var version string = "0" -var commit string = "abcd1234" -var date = "unknown" +var ( + version string = "0" + commit string = "abcd1234" + date string = "unknown" +) func main() { // Define a command-line flag for the port configPath := flag.String("config", "config.yaml", "config file name") listenStr := flag.String("listen", ":8080", "listen ip/port") showVersion := flag.Bool("version", false, "show version of build") + watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change") flag.Parse() // Parse the command-line flags @@ -46,18 +55,135 @@ func main() { proxyManager := proxy.New(config) + // Setup channels for server management + reloadChan := make(chan *proxy.ProxyManager) + exitChan := make(chan struct{}) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Create server with initial handler + srv := &http.Server{ + Addr: *listenStr, + Handler: proxyManager, + } + + // Start server + fmt.Printf("llama-swap listening on %s\n", *listenStr) go func() { - <-sigChan - fmt.Println("Shutting down llama-swap") - proxyManager.Shutdown() - os.Exit(0) + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + fmt.Printf("Fatal server error: %v\n", err) + close(exitChan) + } }() - fmt.Println("llama-swap listening on " + *listenStr) - if err := proxyManager.Run(*listenStr); err != nil { - fmt.Printf("Server error: %v\n", err) - os.Exit(1) + // Handle config reloads and signals + go func() { + currentManager := proxyManager + for { + select { + case newManager := <-reloadChan: + log.Println("Config change detected, waiting for in-flight requests to complete...") + // Stop old manager processes gracefully (this waits for in-flight requests) + currentManager.StopProcesses() + // Now do a full shutdown to clear the process map + currentManager.Shutdown() + currentManager = newManager + srv.Handler = newManager + log.Println("Server handler updated with new config") + case sig := <-sigChan: + fmt.Printf("Received signal %v, shutting down...\n", sig) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + currentManager.Shutdown() + if err := srv.Shutdown(ctx); err != nil { + fmt.Printf("Server shutdown error: %v\n", err) + } + close(exitChan) + return + } + } + }() + + // Start file watcher if requested + if *watchConfig { + absConfigPath, err := filepath.Abs(*configPath) + if err != nil { + log.Printf("Error getting absolute path for config: %v. File watching disabled.", err) + } else { + go watchConfigFileWithReload(absConfigPath, reloadChan) + } + } + + // Wait for exit signal + <-exitChan +} + +// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan. +func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Printf("Error creating file watcher: %v. File watching disabled.", err) + return + } + defer watcher.Close() + + err = watcher.Add(configPath) + if err != nil { + log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err) + return + } + + log.Printf("Watching config file for changes: %s", configPath) + + var debounceTimer *time.Timer + debounceDuration := 2 * time.Second + + for { + select { + case event, ok := <-watcher.Events: + if !ok { + return + } + // We only care about writes to the specific config file + if event.Name == configPath && event.Has(fsnotify.Write) { + // Reset or start the debounce timer + if debounceTimer != nil { + debounceTimer.Stop() + } + debounceTimer = time.AfterFunc(debounceDuration, func() { + log.Printf("Config file modified: %s, reloading...", event.Name) + + // Try up to 3 times with exponential backoff + var newConfig proxy.Config + var err error + for retries := 0; retries < 3; retries++ { + // Load new configuration + newConfig, err = proxy.LoadConfig(configPath) + if err == nil { + break + } + log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err) + if retries < 2 { + time.Sleep(time.Duration(1< cmd.Wait() returned error: %v", p.ID, exitErr) @@ -260,9 +260,9 @@ func (p *Process) start() error { if strings.Contains(err.Error(), "connection refused") { endTime, _ := checkDeadline.Deadline() ttl := time.Until(endTime) - p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds()) + p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds()) } else { - p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err) + p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err) } } } @@ -345,31 +345,33 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) { defer cancelTimeout() if p.cmd == nil || p.cmd.Process == nil { - p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID) + p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID) return } if err := p.terminateProcess(); err != nil { - p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err) + p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err) } select { case <-sigtermTimeout.Done(): - p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID) - p.cmd.Process.Kill() + p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID) + if err := p.cmd.Process.Kill(); err != nil { + p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err) + } case err := <-p.cmdWaitChan: // Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK // because if we make it here then the cmd has been successfully running and made it - // through the health check. There is a possibility that ithe cmd crashed after the health check + // through the health check. There is a possibility that the cmd crashed after the health check // succeeded but that's not a case llama-swap is handling for now. if err != nil { if errno, ok := err.(syscall.Errno); ok { p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) } else if exitError, ok := err.(*exec.ExitError); ok { if strings.Contains(exitError.String(), "signal: terminated") { - p.proxyLogger.Infof("<%s> Process stopped OK", p.ID) + p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID) } else if strings.Contains(exitError.String(), "signal: interrupt") { - p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID) + p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID) } else { p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) } diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index f1c5661..54fc8a7 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -82,6 +82,11 @@ func New(config Config) *ProxyManager { pm.processGroups[groupID] = processGroup } + pm.setupGinEngine() + return pm +} + +func (pm *ProxyManager) setupGinEngine() { pm.ginEngine.Use(func(c *gin.Context) { // Start timer start := time.Now() @@ -192,18 +197,17 @@ func New(config Config) *ProxyManager { // Disable console color for testing gin.DisableConsoleColor() - - return pm } -func (pm *ProxyManager) Run(addr ...string) error { - return pm.ginEngine.Run(addr...) -} - -func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) { +// ServeHTTP implements http.Handler interface +func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) { pm.ginEngine.ServeHTTP(w, r) } +// StopProcesses acquires a lock and stops all running upstream processes. +// This is the public method safe for concurrent calls. +// Unlike Shutdown, this method only stops the processes but doesn't perform +// a complete shutdown, allowing for process replacement without full termination. func (pm *ProxyManager) StopProcesses() { pm.Lock() defer pm.Unlock() @@ -221,8 +225,7 @@ func (pm *ProxyManager) StopProcesses() { wg.Wait() } -// Shutdown is called to shutdown all upstream processes -// when llama-swap is shutting down. +// Shutdown stops all processes managed by this ProxyManager func (pm *ProxyManager) Shutdown() { pm.Lock() defer pm.Unlock() diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 83cbd09..242e4a7 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -34,7 +34,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) { req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), modelName) } @@ -72,10 +72,9 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) { req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), requestedModel) - }) } @@ -115,7 +114,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) { req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), requestedModel) } @@ -158,14 +157,13 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) if w.Code != http.StatusOK { t.Errorf("Expected status OK, got %d for key %s", w.Code, key) } mu.Lock() - var response map[string]string assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) results[key] = response["responseMessage"] @@ -202,7 +200,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { w := httptest.NewRecorder() // Call the listModelsHandler - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) // Check the response status code assert.Equal(t, http.StatusOK, w.Code) @@ -292,7 +290,7 @@ func TestProxyManager_Shutdown(t *testing.T) { w := httptest.NewRecorder() // send a request to trigger the proxy to load ... this should hang waiting for start up - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusBadGateway, w.Code) assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") }(modelName) @@ -318,12 +316,12 @@ func TestProxyManager_Unload(t *testing.T) { reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) req = httptest.NewRequest("GET", "/unload", nil) w = httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, w.Body.String(), "OK") @@ -334,7 +332,6 @@ func TestProxyManager_Unload(t *testing.T) { // Test issue #61 `Listing the current list of models and the loaded model.` func TestProxyManager_RunningEndpoint(t *testing.T) { - // Shared configuration config := AddDefaultGroupToConfig(Config{ HealthCheckTimeout: 15, @@ -360,7 +357,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { t.Run("no models loaded", func(t *testing.T) { req := httptest.NewRequest("GET", "/running", nil) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) @@ -378,13 +375,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) { reqBody := `{"model":"model1"}` req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) // Simulate browser call for the `/running` endpoint. req = httptest.NewRequest("GET", "/running", nil) w = httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) var response RunningResponse assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) @@ -436,7 +433,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req.Header.Set("Content-Type", w.FormDataContentType()) rec := httptest.NewRecorder() - proxy.HandlerFunc(rec, req) + proxy.ServeHTTP(rec, req) // Verify the response assert.Equal(t, http.StatusOK, rec.Code) @@ -473,7 +470,7 @@ func TestProxyManager_UseModelName(t *testing.T) { req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) assert.Contains(t, w.Body.String(), upstreamModelName) }) @@ -500,7 +497,7 @@ func TestProxyManager_UseModelName(t *testing.T) { req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req.Header.Set("Content-Type", w.FormDataContentType()) rec := httptest.NewRecorder() - proxy.HandlerFunc(rec, req) + proxy.ServeHTTP(rec, req) // Verify the response assert.Equal(t, http.StatusOK, rec.Code) @@ -568,7 +565,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { } w := httptest.NewRecorder() - proxy.ginEngine.ServeHTTP(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, tt.expectedStatus, w.Code) @@ -592,7 +589,7 @@ func TestProxyManager_Upstream(t *testing.T) { defer proxy.StopProcesses() req := httptest.NewRequest("GET", "/upstream/model1/test", nil) rec := httptest.NewRecorder() - proxy.HandlerFunc(rec, req) + proxy.ServeHTTP(rec, req) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "model1", rec.Body.String()) } @@ -613,7 +610,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) { req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) w := httptest.NewRecorder() - proxy.HandlerFunc(w, req) + proxy.ServeHTTP(w, req) assert.Equal(t, http.StatusOK, w.Code) var response map[string]string assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))