Add stopCmd for custom stopping instructions (#136)
Allow configuration of how a model is stopped before swapping. Setting `cmdStop` in the configuration will override the default behaviour and enables better integration with other process/container managers like docker or podman.
This commit is contained in:
@@ -17,6 +17,7 @@ const DEFAULT_GROUP_ID = "(default)"
|
||||
|
||||
type ModelConfig struct {
|
||||
Cmd string `yaml:"cmd"`
|
||||
CmdStop string `yaml:"cmdStop"`
|
||||
Proxy string `yaml:"proxy"`
|
||||
Aliases []string `yaml:"aliases"`
|
||||
Env []string `yaml:"env"`
|
||||
@@ -135,7 +136,6 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
}
|
||||
|
||||
// iterate over the models and replace any ${PORT} with the next available port
|
||||
// Get and sort all model IDs first, makes testing more consistent
|
||||
modelIds := make([]string, 0, len(config.Models))
|
||||
for modelId := range config.Models {
|
||||
@@ -143,10 +143,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
}
|
||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||
|
||||
// iterate over the sorted models
|
||||
nextPort := config.StartPort
|
||||
for _, modelId := range modelIds {
|
||||
modelConfig := config.Models[modelId]
|
||||
// iterate over the models and replace any ${PORT} with the next available port
|
||||
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
||||
if modelConfig.Proxy == "" {
|
||||
@@ -160,6 +160,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
||||
}
|
||||
}
|
||||
|
||||
config = AddDefaultGroupToConfig(config)
|
||||
// check that members are all unique in the groups
|
||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||
|
||||
@@ -38,5 +38,4 @@ func TestConfig_SanitizeCommand(t *testing.T) {
|
||||
args, err = SanitizeCommand("")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, args)
|
||||
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -400,8 +401,38 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := p.terminateProcess(); err != nil {
|
||||
p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
|
||||
// if err := p.terminateProcess(); err != nil {
|
||||
// p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
|
||||
// }
|
||||
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||
if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" {
|
||||
p.config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
}
|
||||
|
||||
if p.config.CmdStop != "" {
|
||||
// replace ${PID} with the pid of the process
|
||||
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||
if err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||
|
||||
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||
stopCmd.Stdout = p.processLogger
|
||||
stopCmd.Stderr = p.processLogger
|
||||
stopCmd.Env = p.config.Env
|
||||
|
||||
if err := stopCmd.Run(); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package proxy
|
||||
|
||||
import "syscall"
|
||||
|
||||
func (p *Process) terminateProcess() error {
|
||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
)
|
||||
|
||||
func (p *Process) terminateProcess() error {
|
||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
||||
return cmd.Run()
|
||||
}
|
||||
@@ -449,3 +449,22 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
||||
// the request should have been interrupted by SIGKILL
|
||||
<-waitChan
|
||||
}
|
||||
|
||||
func TestProcess_StopCmd(t *testing.T) {
|
||||
config := getTestSimpleResponderConfig("test_stop_cmd")
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||
} else {
|
||||
config.CmdStop = "kill -TERM ${PID}"
|
||||
}
|
||||
|
||||
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
||||
defer process.Stop()
|
||||
|
||||
err := process.start()
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, process.CurrentState(), StateReady)
|
||||
process.StopImmediately()
|
||||
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user