Add stopCmd for custom stopping instructions (#136)
Allow configuration of how a model is stopped before swapping. Setting `cmdStop` in the configuration will override the default behaviour and enables better integration with other process/container managers like docker or podman.
This commit is contained in:
@@ -129,6 +129,10 @@ models:
|
|||||||
ghcr.io/ggerganov/llama.cpp:server
|
ghcr.io/ggerganov/llama.cpp:server
|
||||||
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
|
||||||
|
|
||||||
|
# use a custom command to stop the model when swapping. By default
|
||||||
|
# this is SIGTERM on POSIX systems, and taskkill on Windows systems
|
||||||
|
cmdStop: docker stop dockertest
|
||||||
|
|
||||||
# Groups provide advanced controls over model swapping behaviour. Using groups
|
# Groups provide advanced controls over model swapping behaviour. Using groups
|
||||||
# some models can be kept loaded indefinitely, while others are swapped out.
|
# some models can be kept loaded indefinitely, while others are swapped out.
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ const DEFAULT_GROUP_ID = "(default)"
|
|||||||
|
|
||||||
type ModelConfig struct {
|
type ModelConfig struct {
|
||||||
Cmd string `yaml:"cmd"`
|
Cmd string `yaml:"cmd"`
|
||||||
|
CmdStop string `yaml:"cmdStop"`
|
||||||
Proxy string `yaml:"proxy"`
|
Proxy string `yaml:"proxy"`
|
||||||
Aliases []string `yaml:"aliases"`
|
Aliases []string `yaml:"aliases"`
|
||||||
Env []string `yaml:"env"`
|
Env []string `yaml:"env"`
|
||||||
@@ -135,7 +136,6 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// iterate over the models and replace any ${PORT} with the next available port
|
|
||||||
// Get and sort all model IDs first, makes testing more consistent
|
// Get and sort all model IDs first, makes testing more consistent
|
||||||
modelIds := make([]string, 0, len(config.Models))
|
modelIds := make([]string, 0, len(config.Models))
|
||||||
for modelId := range config.Models {
|
for modelId := range config.Models {
|
||||||
@@ -143,10 +143,10 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
sort.Strings(modelIds) // This guarantees stable iteration order
|
sort.Strings(modelIds) // This guarantees stable iteration order
|
||||||
|
|
||||||
// iterate over the sorted models
|
|
||||||
nextPort := config.StartPort
|
nextPort := config.StartPort
|
||||||
for _, modelId := range modelIds {
|
for _, modelId := range modelIds {
|
||||||
modelConfig := config.Models[modelId]
|
modelConfig := config.Models[modelId]
|
||||||
|
// iterate over the models and replace any ${PORT} with the next available port
|
||||||
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
if strings.Contains(modelConfig.Cmd, "${PORT}") {
|
||||||
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
modelConfig.Cmd = strings.ReplaceAll(modelConfig.Cmd, "${PORT}", strconv.Itoa(nextPort))
|
||||||
if modelConfig.Proxy == "" {
|
if modelConfig.Proxy == "" {
|
||||||
@@ -160,6 +160,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
return Config{}, fmt.Errorf("model %s requires a proxy value when not using automatic ${PORT}", modelId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config = AddDefaultGroupToConfig(config)
|
config = AddDefaultGroupToConfig(config)
|
||||||
// check that members are all unique in the groups
|
// check that members are all unique in the groups
|
||||||
memberUsage := make(map[string]string) // maps member to group it appears in
|
memberUsage := make(map[string]string) // maps member to group it appears in
|
||||||
|
|||||||
@@ -38,5 +38,4 @@ func TestConfig_SanitizeCommand(t *testing.T) {
|
|||||||
args, err = SanitizeCommand("")
|
args, err = SanitizeCommand("")
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, args)
|
assert.Nil(t, args)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -400,8 +401,38 @@ func (p *Process) stopCommand(sigtermTTL time.Duration) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := p.terminateProcess(); err != nil {
|
// if err := p.terminateProcess(); err != nil {
|
||||||
p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
|
// p.proxyLogger.Debugf("<%s> Process already terminated: %v (normal during shutdown)", p.ID, err)
|
||||||
|
// }
|
||||||
|
// the default cmdStop to taskkill /f /t /pid ${PID}
|
||||||
|
if runtime.GOOS == "windows" && strings.TrimSpace(p.config.CmdStop) == "" {
|
||||||
|
p.config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.config.CmdStop != "" {
|
||||||
|
// replace ${PID} with the pid of the process
|
||||||
|
stopArgs, err := SanitizeCommand(strings.ReplaceAll(p.config.CmdStop, "${PID}", fmt.Sprintf("%d", p.cmd.Process.Pid)))
|
||||||
|
if err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to sanitize stop command: %v", p.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.proxyLogger.Debugf("<%s> Executing stop command: %s", p.ID, strings.Join(stopArgs, " "))
|
||||||
|
|
||||||
|
stopCmd := exec.Command(stopArgs[0], stopArgs[1:]...)
|
||||||
|
stopCmd.Stdout = p.processLogger
|
||||||
|
stopCmd.Stderr = p.processLogger
|
||||||
|
stopCmd.Env = p.config.Env
|
||||||
|
|
||||||
|
if err := stopCmd.Run(); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to exec stop command: %v", p.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := p.cmd.Process.Signal(syscall.SIGTERM); err != nil {
|
||||||
|
p.proxyLogger.Errorf("<%s> Failed to send SIGTERM to process: %v", p.ID, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import "syscall"
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
return p.cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os/exec"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (p *Process) terminateProcess() error {
|
|
||||||
pid := fmt.Sprintf("%d", p.cmd.Process.Pid)
|
|
||||||
cmd := exec.Command("taskkill", "/f", "/t", "/pid", pid)
|
|
||||||
return cmd.Run()
|
|
||||||
}
|
|
||||||
@@ -449,3 +449,22 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
// the request should have been interrupted by SIGKILL
|
// the request should have been interrupted by SIGKILL
|
||||||
<-waitChan
|
<-waitChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcess_StopCmd(t *testing.T) {
|
||||||
|
config := getTestSimpleResponderConfig("test_stop_cmd")
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
config.CmdStop = "taskkill /f /t /pid ${PID}"
|
||||||
|
} else {
|
||||||
|
config.CmdStop = "kill -TERM ${PID}"
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("testStopCmd", 2, config, debugLogger, debugLogger)
|
||||||
|
defer process.Stop()
|
||||||
|
|
||||||
|
err := process.start()
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.Equal(t, process.CurrentState(), StateReady)
|
||||||
|
process.StopImmediately()
|
||||||
|
assert.Equal(t, process.CurrentState(), StateStopped)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user