Add config hot-reload (#106)

introduce --watch-config command line option to reload ProxyManager when configuration changes.
This commit is contained in:
Sam
2025-05-12 10:37:00 +10:00
committed by GitHub
parent 9548931258
commit bc652709a5
7 changed files with 196 additions and 72 deletions

View File

@@ -195,7 +195,7 @@ groups:
Docker is the quickest way to try out llama-swap: Docker is the quickest way to try out llama-swap:
``` ```shell
# use CPU inference # use CPU inference
$ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu $ docker run -it --rm -p 9292:8080 ghcr.io/mostlygeek/llama-swap:cpu
@@ -231,7 +231,7 @@ Specific versions are also available and are tagged with the llama-swap, archite
Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration. Beyond the demo you will likely want to run the containers with your downloaded models and custom configuration.
``` ```shell
$ docker run -it --rm --runtime nvidia -p 9292:8080 \ $ docker run -it --rm --runtime nvidia -p 9292:8080 \
-v /path/to/models:/models \ -v /path/to/models:/models \
-v /path/to/custom/config.yaml:/app/config.yaml \ -v /path/to/custom/config.yaml:/app/config.yaml \
@@ -246,7 +246,12 @@ Pre-built binaries are available for Linux, FreeBSD and Darwin (OSX). These are
1. Create a configuration file, see [config.example.yaml](config.example.yaml) 1. Create a configuration file, see [config.example.yaml](config.example.yaml)
1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture. 1. Download a [release](https://github.com/mostlygeek/llama-swap/releases) appropriate for your OS and architecture.
1. Run the binary with `llama-swap --config path/to/config.yaml` 1. Run the binary with `llama-swap --config path/to/config.yaml`.
Available flags:
- `--config`: Path to the configuration file (default: `config.yaml`).
- `--listen`: Address and port to listen on (default: `:8080`).
- `--version`: Show version information and exit.
- `--watch-config`: Automatically reload the configuration file when it changes. This will wait for in-flight requests to complete then stop all running models (default: `false`).
### Building from source ### Building from source
@@ -261,7 +266,7 @@ Open the `http://<host>/logs` with your browser to get a web interface with stre
Of course, CLI access is also supported: Of course, CLI access is also supported:
``` ```shell
# sends up to the last 10KB of logs # sends up to the last 10KB of logs
curl http://host/logs' curl http://host/logs'

1
go.mod
View File

@@ -3,6 +3,7 @@ module github.com/mostlygeek/llama-swap
go 1.23.0 go 1.23.0
require ( require (
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/stretchr/testify v1.9.0 github.com/stretchr/testify v1.9.0

26
go.sum
View File

@@ -9,12 +9,16 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU=
github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
@@ -23,6 +27,8 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
@@ -74,34 +80,18 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.23.0 h1:dIJU/v2J8Mdglj/8rJ6UUOM3Zc9zLZxVZwwxMooUSAI=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34= golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc= golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I=
golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4=
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8= golang.org/x/net v0.38.0 h1:vRMAPTMaeGqVhG5QyLJHqNDwecKTomGeqbnfZyKlBI8=
golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik= golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/text v0.15.0 h1:h1V/4gjBv8v9cjcR6+AR5+/cIYK5N/WAgiv4xlsEtAk=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY= golang.org/x/text v0.23.0 h1:D71I7dUrlY+VX0gQShAThNGHFxZ13dGLBHQLVl1mJlY=
golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=

View File

@@ -1,25 +1,34 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"log"
"net/http"
"os" "os"
"os/signal" "os/signal"
"path/filepath"
"syscall" "syscall"
"time"
"github.com/fsnotify/fsnotify"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/proxy" "github.com/mostlygeek/llama-swap/proxy"
) )
var version string = "0" var (
var commit string = "abcd1234" version string = "0"
var date = "unknown" commit string = "abcd1234"
date string = "unknown"
)
func main() { func main() {
// Define a command-line flag for the port // Define a command-line flag for the port
configPath := flag.String("config", "config.yaml", "config file name") configPath := flag.String("config", "config.yaml", "config file name")
listenStr := flag.String("listen", ":8080", "listen ip/port") listenStr := flag.String("listen", ":8080", "listen ip/port")
showVersion := flag.Bool("version", false, "show version of build") showVersion := flag.Bool("version", false, "show version of build")
watchConfig := flag.Bool("watch-config", false, "Automatically reload config file on change")
flag.Parse() // Parse the command-line flags flag.Parse() // Parse the command-line flags
@@ -46,18 +55,135 @@ func main() {
proxyManager := proxy.New(config) proxyManager := proxy.New(config)
// Setup channels for server management
reloadChan := make(chan *proxy.ProxyManager)
exitChan := make(chan struct{})
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Create server with initial handler
srv := &http.Server{
Addr: *listenStr,
Handler: proxyManager,
}
// Start server
fmt.Printf("llama-swap listening on %s\n", *listenStr)
go func() { go func() {
<-sigChan if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
fmt.Println("Shutting down llama-swap") fmt.Printf("Fatal server error: %v\n", err)
proxyManager.Shutdown() close(exitChan)
os.Exit(0) }
}() }()
fmt.Println("llama-swap listening on " + *listenStr) // Handle config reloads and signals
if err := proxyManager.Run(*listenStr); err != nil { go func() {
fmt.Printf("Server error: %v\n", err) currentManager := proxyManager
os.Exit(1) for {
select {
case newManager := <-reloadChan:
log.Println("Config change detected, waiting for in-flight requests to complete...")
// Stop old manager processes gracefully (this waits for in-flight requests)
currentManager.StopProcesses()
// Now do a full shutdown to clear the process map
currentManager.Shutdown()
currentManager = newManager
srv.Handler = newManager
log.Println("Server handler updated with new config")
case sig := <-sigChan:
fmt.Printf("Received signal %v, shutting down...\n", sig)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentManager.Shutdown()
if err := srv.Shutdown(ctx); err != nil {
fmt.Printf("Server shutdown error: %v\n", err)
}
close(exitChan)
return
}
}
}()
// Start file watcher if requested
if *watchConfig {
absConfigPath, err := filepath.Abs(*configPath)
if err != nil {
log.Printf("Error getting absolute path for config: %v. File watching disabled.", err)
} else {
go watchConfigFileWithReload(absConfigPath, reloadChan)
}
}
// Wait for exit signal
<-exitChan
}
// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan.
func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) {
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Printf("Error creating file watcher: %v. File watching disabled.", err)
return
}
defer watcher.Close()
err = watcher.Add(configPath)
if err != nil {
log.Printf("Error adding config path (%s) to watcher: %v. File watching disabled.", configPath, err)
return
}
log.Printf("Watching config file for changes: %s", configPath)
var debounceTimer *time.Timer
debounceDuration := 2 * time.Second
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
// We only care about writes to the specific config file
if event.Name == configPath && event.Has(fsnotify.Write) {
// Reset or start the debounce timer
if debounceTimer != nil {
debounceTimer.Stop()
}
debounceTimer = time.AfterFunc(debounceDuration, func() {
log.Printf("Config file modified: %s, reloading...", event.Name)
// Try up to 3 times with exponential backoff
var newConfig proxy.Config
var err error
for retries := 0; retries < 3; retries++ {
// Load new configuration
newConfig, err = proxy.LoadConfig(configPath)
if err == nil {
break
}
log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err)
if retries < 2 {
time.Sleep(time.Duration(1<<retries) * time.Second)
}
}
if err != nil {
log.Printf("Failed to load new config after retries: %v", err)
return
}
// Create new ProxyManager with new config
newPM := proxy.New(newConfig)
reloadChan <- newPM
log.Println("Config reloaded successfully")
})
}
case err, ok := <-watcher.Errors:
if !ok {
log.Println("File watcher error channel closed.")
return
}
log.Printf("File watcher error: %v", err)
}
} }
} }

View File

@@ -185,7 +185,7 @@ func (p *Process) start() error {
return fmt.Errorf("start() failed: %v", err) return fmt.Errorf("start() failed: %v", err)
} }
// Capture the exit error for later signaling // Capture the exit error for later signalling
go func() { go func() {
exitErr := p.cmd.Wait() exitErr := p.cmd.Wait()
p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr) p.proxyLogger.Debugf("<%s> cmd.Wait() returned error: %v", p.ID, exitErr)
@@ -260,9 +260,9 @@ func (p *Process) start() error {
if strings.Contains(err.Error(), "connection refused") { if strings.Contains(err.Error(), "connection refused") {
endTime, _ := checkDeadline.Deadline() endTime, _ := checkDeadline.Deadline()
ttl := time.Until(endTime) ttl := time.Until(endTime)
p.proxyLogger.Infof("<%s> Connection refused on %s, giving up in %.0fs", p.ID, healthURL, ttl.Seconds()) p.proxyLogger.Debugf("<%s> Connection refused on %s, giving up in %.0fs (normal during startup)", p.ID, healthURL, ttl.Seconds())
} else { } else {
p.proxyLogger.Infof("<%s> Health check error on %s, %v", p.ID, healthURL, err) p.proxyLogger.Debugf("<%s> Health check error on %s, %v (normal during startup)", p.ID, healthURL, err)
} }
} }
} }
@@ -345,31 +345,33 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
defer cancelTimeout() defer cancelTimeout()
if p.cmd == nil || p.cmd.Process == nil { if p.cmd == nil || p.cmd.Process == nil {
p.proxyLogger.Warnf("<%s> cmd or cmd.Process is nil", p.ID) p.proxyLogger.Debugf("<%s> cmd or cmd.Process is nil (normal during config reload)", p.ID)
return return
} }
if err := p.terminateProcess(); err != nil { if err := p.terminateProcess(); err != nil {
p.proxyLogger.Infof("<%s> Failed to gracefully terminate process: %v", p.ID, err) p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
} }
select { select {
case <-sigtermTimeout.Done(): case <-sigtermTimeout.Done():
p.proxyLogger.Infof("<%s> Process timed out waiting to stop, sending KILL signal", p.ID) p.proxyLogger.Debugf("<%s> Process timed out waiting to stop, sending KILL signal (normal during shutdown)", p.ID)
p.cmd.Process.Kill() if err := p.cmd.Process.Kill(); err != nil {
p.proxyLogger.Errorf("<%s> Failed to kill process: %v", p.ID, err)
}
case err := <-p.cmdWaitChan: case err := <-p.cmdWaitChan:
// Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK // Note: in start(), p.cmdWaitChan also has a select { ... }. That should be OK
// because if we make it here then the cmd has been successfully running and made it // because if we make it here then the cmd has been successfully running and made it
// through the health check. There is a possibility that ithe cmd crashed after the health check // through the health check. There is a possibility that the cmd crashed after the health check
// succeeded but that's not a case llama-swap is handling for now. // succeeded but that's not a case llama-swap is handling for now.
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok { if errno, ok := err.(syscall.Errno); ok {
p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno) p.proxyLogger.Errorf("<%s> errno >> %v", p.ID, errno)
} else if exitError, ok := err.(*exec.ExitError); ok { } else if exitError, ok := err.(*exec.ExitError); ok {
if strings.Contains(exitError.String(), "signal: terminated") { if strings.Contains(exitError.String(), "signal: terminated") {
p.proxyLogger.Infof("<%s> Process stopped OK", p.ID) p.proxyLogger.Debugf("<%s> Process stopped OK", p.ID)
} else if strings.Contains(exitError.String(), "signal: interrupt") { } else if strings.Contains(exitError.String(), "signal: interrupt") {
p.proxyLogger.Infof("<%s> Process interrupted OK", p.ID) p.proxyLogger.Debugf("<%s> Process interrupted OK", p.ID)
} else { } else {
p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode()) p.proxyLogger.Warnf("<%s> ExitError >> %v, exit code: %d", p.ID, exitError, exitError.ExitCode())
} }

View File

@@ -82,6 +82,11 @@ func New(config Config) *ProxyManager {
pm.processGroups[groupID] = processGroup pm.processGroups[groupID] = processGroup
} }
pm.setupGinEngine()
return pm
}
func (pm *ProxyManager) setupGinEngine() {
pm.ginEngine.Use(func(c *gin.Context) { pm.ginEngine.Use(func(c *gin.Context) {
// Start timer // Start timer
start := time.Now() start := time.Now()
@@ -192,18 +197,17 @@ func New(config Config) *ProxyManager {
// Disable console color for testing // Disable console color for testing
gin.DisableConsoleColor() gin.DisableConsoleColor()
return pm
} }
func (pm *ProxyManager) Run(addr ...string) error { // ServeHTTP implements http.Handler interface
return pm.ginEngine.Run(addr...) func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r) pm.ginEngine.ServeHTTP(w, r)
} }
// StopProcesses acquires a lock and stops all running upstream processes.
// This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses() { func (pm *ProxyManager) StopProcesses() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()
@@ -221,8 +225,7 @@ func (pm *ProxyManager) StopProcesses() {
wg.Wait() wg.Wait()
} }
// Shutdown is called to shutdown all upstream processes // Shutdown stops all processes managed by this ProxyManager
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() { func (pm *ProxyManager) Shutdown() {
pm.Lock() pm.Lock()
defer pm.Unlock() defer pm.Unlock()

View File

@@ -34,7 +34,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName) assert.Contains(t, w.Body.String(), modelName)
} }
@@ -72,10 +72,9 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel) assert.Contains(t, w.Body.String(), requestedModel)
}) })
} }
@@ -115,7 +114,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), requestedModel) assert.Contains(t, w.Body.String(), requestedModel)
} }
@@ -158,14 +157,13 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
if w.Code != http.StatusOK { if w.Code != http.StatusOK {
t.Errorf("Expected status OK, got %d for key %s", w.Code, key) t.Errorf("Expected status OK, got %d for key %s", w.Code, key)
} }
mu.Lock() mu.Lock()
var response map[string]string var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"] results[key] = response["responseMessage"]
@@ -202,7 +200,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Call the listModelsHandler // Call the listModelsHandler
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
// Check the response status code // Check the response status code
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -292,7 +290,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
// send a request to trigger the proxy to load ... this should hang waiting for start up // send a request to trigger the proxy to load ... this should hang waiting for start up
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusBadGateway, w.Code) assert.Equal(t, http.StatusBadGateway, w.Code)
assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown") assert.Contains(t, w.Body.String(), "health check interrupted due to shutdown")
}(modelName) }(modelName)
@@ -318,12 +316,12 @@ func TestProxyManager_Unload(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1") reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady) assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
req = httptest.NewRequest("GET", "/unload", nil) req = httptest.NewRequest("GET", "/unload", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
@@ -334,7 +332,6 @@ func TestProxyManager_Unload(t *testing.T) {
// Test issue #61 `Listing the current list of models and the loaded model.` // Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) { func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration // Shared configuration
config := AddDefaultGroupToConfig(Config{ config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
@@ -360,7 +357,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
t.Run("no models loaded", func(t *testing.T) { t.Run("no models loaded", func(t *testing.T) {
req := httptest.NewRequest("GET", "/running", nil) req := httptest.NewRequest("GET", "/running", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
@@ -378,13 +375,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
reqBody := `{"model":"model1"}` reqBody := `{"model":"model1"}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
// Simulate browser call for the `/running` endpoint. // Simulate browser call for the `/running` endpoint.
req = httptest.NewRequest("GET", "/running", nil) req = httptest.NewRequest("GET", "/running", nil)
w = httptest.NewRecorder() w = httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
var response RunningResponse var response RunningResponse
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
@@ -436,7 +433,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req) proxy.ServeHTTP(rec, req)
// Verify the response // Verify the response
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
@@ -473,7 +470,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName) assert.Contains(t, w.Body.String(), upstreamModelName)
}) })
@@ -500,7 +497,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b) req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType()) req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req) proxy.ServeHTTP(rec, req)
// Verify the response // Verify the response
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
@@ -568,7 +565,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ginEngine.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, tt.expectedStatus, w.Code) assert.Equal(t, tt.expectedStatus, w.Code)
@@ -592,7 +589,7 @@ func TestProxyManager_Upstream(t *testing.T) {
defer proxy.StopProcesses() defer proxy.StopProcesses()
req := httptest.NewRequest("GET", "/upstream/model1/test", nil) req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
rec := httptest.NewRecorder() rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req) proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String()) assert.Equal(t, "model1", rec.Body.String())
} }
@@ -613,7 +610,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.HandlerFunc(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))