Add stopCmd for custom stopping instructions (#136)

Allow configuration of how a model is stopped before swapping. Setting `cmdStop` in the configuration will override the default behaviour and enables better integration with other process/container managers like docker or podman.
This commit is contained in:
Benson Wong
2025-05-16 13:48:42 -07:00
committed by GitHub
parent f9ee7156dc
commit a8b81f2799
7 changed files with 59 additions and 28 deletions

View File

@@ -17,6 +17,7 @@ const DEFAULT_GROUP_ID = "(default)"
type ModelConfig struct {
Cmd string `yaml:"cmd"`
CmdStop string `yaml:"cmdStop"`
Proxy string `yaml:"proxy"`
Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"`
@@ -135,7 +136,6 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
}
// iterate over the models and replace any ${PORT} with the next available port
// Get and sort all model IDs first, makes testing more consistent
modelIds := make([]string, 0, len(config.Models))
for modelId := range config.Models {
@@ -143,10 +143,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
}
sort.Strings(modelIds) // This guarantees stable iteration order
// iterate over the sorted models
nextPort := config.StartPort
for _, modelId := range modelIds {
modelConfig := config.Models[modelId]
// iterate over the models and replace any ${PORT} with the next available port
if strings.Contains(modelConfig.Cmd, "${PORT}") {
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
if modelConfig.Proxy == "" {
@@ -160,6 +160,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
}
}
config = AddDefaultGroupToConfig(config)
// check that members are all unique in the groups
memberUsage := make(map[string]string) // maps member to group it appears in

View File

@@ -38,5 +38,4 @@ func TestConfig_SanitizeCommand(t *testing.T) {
args, err = SanitizeCommand("")
assert.Error(t, err)
assert.Nil(t, args)
}

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"net/url"
"os/exec"
"runtime"
"strconv"
"strings"
"sync"
@@ -400,8 +401,38 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
return
}
if err := p.terminateProcess(); err != nil {
p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
// if err := p.terminateProcess(); err != nil {
// p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
// }
// the default cmdStop to taskkill /f /t /pid ${PID}
if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" {
p.config.CmdStop = "taskkill /f /t /pid ${PID}"
}
if p.config.CmdStop != "" {
// replace ${PID} with the pid of the process
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
if err != nil {
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
return
}
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
stopCmd.Stdout = p.processLogger
stopCmd.Stderr = p.processLogger
stopCmd.Env = p.config.Env
if err := stopCmd.Run(); err != nil {
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
return
}
} else {
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
return
}
}
select {

View File

@@ -1,9 +0,0 @@
//go:build !windows
package proxy
import "syscall"
func (p *Process) terminateProcess() error {
return p.cmd.Process.Signal(syscall.SIGTERM)
}

View File

@@ -1,14 +0,0 @@
//go:build windows
package proxy
import (
"fmt"
"os/exec"
)
func (p *Process) terminateProcess() error {
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
return cmd.Run()
}

View File

@@ -449,3 +449,22 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
// the request should have been interrupted by SIGKILL
<-waitChan
}
func TestProcess_StopCmd(t *testing.T) {
config := getTestSimpleResponderConfig("test_stop_cmd")
if runtime.GOOS == "windows" {
config.CmdStop = "taskkill /f /t /pid ${PID}"
} else {
config.CmdStop = "kill -TERM ${PID}"
}
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
defer process.Stop()
err := process.start()
assert.Nil(t, err)
assert.Equal(t, process.CurrentState(), StateReady)
process.StopImmediately()
assert.Equal(t, process.CurrentState(), StateStopped)
}