diff --git a/README.md b/README.md index 2327c01..16ebe12 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,10 @@ models: ghcr.io/ggerganov/llama.cpp:server --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf' + # use a custom command to stop the model when swapping. By default + # this is SIGTERM on POSIX systems, and taskkill on Windows systems + cmdStop: docker stop dockertest + # Groups provide advanced controls over model swapping behaviour. Using groups # some models can be kept loaded indefinitely, while others are swapped out. # diff --git a/proxy/config.go b/proxy/config.go index cd289fe..df9f7d0 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -17,6 +17,7 @@ const DEFAULT_GROUP_ID = "(default)" type ModelConfig struct { Cmd string `yaml:"cmd"` + CmdStop string `yaml:"cmdStop"` Proxy string `yaml:"proxy"` Aliases []string `yaml:"aliases"` Env []string `yaml:"env"` @@ -135,7 +136,6 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } } - // iterate over the models and replace any ${PORT} with the next available port // Get and sort all model IDs first, makes testing more consistent modelIds := make([]string, 0, len(config.Models)) for modelId := range config.Models { @@ -143,10 +143,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { } sort.Strings(modelIds) // This guarantees stable iteration order - // iterate over the sorted models nextPort := config.StartPort for _, modelId := range modelIds { modelConfig := config.Models[modelId] + // iterate over the models and replace any ${PORT} with the next available port if strings.Contains(modelConfig.Cmd, "${PORT}") { modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort)) if modelConfig.Proxy == "" { @@ -160,6 +160,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) { return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId) } } + config = AddDefaultGroupToConfig(config) // check that 
members are all unique in the groups memberUsage := make(map[string]string) // maps member to group it appears in diff --git a/proxy/config_windows_test.go b/proxy/config_windows_test.go index ec3c3d7..d301a06 100644 --- a/proxy/config_windows_test.go +++ b/proxy/config_windows_test.go @@ -38,5 +38,4 @@ func TestConfig_SanitizeCommand(t *testing.T) { args, err = SanitizeCommand("") assert.Error(t, err) assert.Nil(t, args) - } diff --git a/proxy/process.go b/proxy/process.go index df4df1e..2870139 100644 --- a/proxy/process.go +++ b/proxy/process.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "os/exec" + "runtime" "strconv" "strings" "sync" @@ -400,8 +401,38 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) { return } - if err := p.terminateProcess(); err != nil { - p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err) + // if err := p.terminateProcess(); err != nil { + // p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err) + // } + // on Windows, default cmdStop to taskkill /f /t /pid ${PID} when none is configured + if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" { + p.config.CmdStop = "taskkill /f /t /pid ${PID}" + } + + if p.config.CmdStop != "" { + // replace ${PID} with the pid of the process + stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid))) + if err != nil { + p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err) + return + } + + p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " ")) + + stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...) 
+ stopCmd.Stdout = p.processLogger + stopCmd.Stderr = p.processLogger + stopCmd.Env = p.config.Env + + if err := stopCmd.Run(); err != nil { + p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err) + return + } + } else { + if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil { + p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err) + return + } } select { diff --git a/proxy/process_stop.go b/proxy/process_stop.go deleted file mode 100644 index 47ed697..0000000 --- a/proxy/process_stop.go +++ /dev/null @@ -1,9 +0,0 @@ -//go:build !windows - -package proxy - -import "syscall" - -func (p *Process) terminateProcess() error { - return p.cmd.Process.Signal(syscall.SIGTERM) -} diff --git a/proxy/process_stop_windows.go b/proxy/process_stop_windows.go deleted file mode 100644 index 245f5bf..0000000 --- a/proxy/process_stop_windows.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build windows - -package proxy - -import ( - "fmt" - "os/exec" -) - -func (p *Process) terminateProcess() error { - pid := fmt.Sprintf("%d", p.cmd.Process.Pid) - cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid) - return cmd.Run() -} diff --git a/proxy/process_test.go b/proxy/process_test.go index 885cbfd..6896803 100644 --- a/proxy/process_test.go +++ b/proxy/process_test.go @@ -449,3 +449,22 @@ func TestProcess_ForceStopWithKill(t *testing.T) { // the request should have been interrupted by SIGKILL <-waitChan } + +func TestProcess_StopCmd(t *testing.T) { + config := getTestSimpleResponderConfig("test_stop_cmd") + + if runtime.GOOS == "windows" { + config.CmdStop = "taskkill /f /t /pid ${PID}" + } else { + config.CmdStop = "kill -TERM ${PID}" + } + + process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger) + defer process.Stop() + + err := process.start() + assert.Nil(t, err) + assert.Equal(t, process.CurrentState(), StateReady) + process.StopImmediately() + assert.Equal(t, process.CurrentState(), StateStopped) +}