add support for automatically unloading a model (#10) (#14)

* Make starting upstream process on-demand (#10)
* Add automatic unload of model after TTL is reached
* Add `ttl` configuration parameter to models, in seconds; the default is 0 (never unload)
This commit is contained in:
Benson Wong
2024-11-19 16:32:51 -08:00
committed by GitHub
parent ba39ed4c18
commit 533162ce6a
8 changed files with 149 additions and 54 deletions

View File

@@ -10,6 +10,9 @@ clean:
rm -rf $(BUILD_DIR) rm -rf $(BUILD_DIR)
test: test:
go test -short -v ./proxy
test-all:
go test -v ./proxy go test -v ./proxy
# Build OSX binary # Build OSX binary

View File

@@ -39,6 +39,11 @@ models:
# until the server is ready # until the server is ready
checkEndpoint: /custom-endpoint checkEndpoint: /custom-endpoint
# automatically unload the model after 5 seconds without a request
# ttl is in seconds and must be greater than 0 to enable unloading
# default: 0 = never unload model
ttl: 5
"qwen": "qwen":
# environment variables to pass to the command # environment variables to pass to the command
env: env:

View File

@@ -17,6 +17,9 @@ models:
# check this path for a HTTP 200 response for the server to be ready # check this path for a HTTP 200 response for the server to be ready
checkEndpoint: /health checkEndpoint: /health
# unload model after 5 seconds
ttl: 5
"qwen": "qwen":
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
proxy: http://127.0.0.1:8999 proxy: http://127.0.0.1:8999

View File

@@ -20,7 +20,11 @@ func main() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
// Set the header to text/plain // Set the header to text/plain
w.Header().Set("Content-Type", "text/plain") w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, *responseMessage)
})
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain")
fmt.Fprintln(w, *responseMessage) fmt.Fprintln(w, *responseMessage)
// Get environment variables // Get environment variables

View File

@@ -14,6 +14,7 @@ type ModelConfig struct {
Aliases []string `yaml:"aliases"` Aliases []string `yaml:"aliases"`
Env []string `yaml:"env"` Env []string `yaml:"env"`
CheckEndpoint string `yaml:"checkEndpoint"` CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
} }
func (m *ModelConfig) SanitizedCommand() ([]string, error) { func (m *ModelConfig) SanitizedCommand() ([]string, error) {

View File

@@ -17,27 +17,33 @@ import (
type Process struct { type Process struct {
sync.Mutex sync.Mutex
ID string ID string
config ModelConfig config ModelConfig
cmd *exec.Cmd cmd *exec.Cmd
logMonitor *LogMonitor logMonitor *LogMonitor
healthCheckTimeout int
isRunning bool
lastRequestHandled time.Time
} }
func NewProcess(ID string, config ModelConfig, logMonitor *LogMonitor) *Process { func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
return &Process{ return &Process{
ID: ID, ID: ID,
config: config, config: config,
cmd: nil, cmd: nil,
logMonitor: logMonitor, logMonitor: logMonitor,
healthCheckTimeout: healthCheckTimeout,
} }
} }
func (p *Process) Start(healthCheckTimeout int) error { // start the process and check it for errors
func (p *Process) start() error {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if p.cmd != nil { if p.isRunning {
return fmt.Errorf("process already started") return fmt.Errorf("process already running")
} }
args, err := p.config.SanitizedCommand() args, err := p.config.SanitizedCommand()
@@ -51,6 +57,8 @@ func (p *Process) Start(healthCheckTimeout int) error {
p.cmd.Env = p.config.Env p.cmd.Env = p.config.Env
err = p.cmd.Start() err = p.cmd.Start()
p.isRunning = true
if err != nil { if err != nil {
return err return err
} }
@@ -58,7 +66,8 @@ func (p *Process) Start(healthCheckTimeout int) error {
// watch for the command to exit // watch for the command to exit
cmdCtx, cancel := context.WithCancelCause(context.Background()) cmdCtx, cancel := context.WithCancelCause(context.Background())
// monitor the command's exit status // monitor the command's exit status. Usually this happens if
// the process exited unexpectedly
go func() { go func() {
err := p.cmd.Wait() err := p.cmd.Wait()
if err != nil { if err != nil {
@@ -66,13 +75,37 @@ func (p *Process) Start(healthCheckTimeout int) error {
} else { } else {
cancel(nil) cancel(nil)
} }
p.isRunning = false
}() }()
// wait a bit for process to start before checking the health endpoint
time.Sleep(250 * time.Millisecond)
// wait for checkHealthEndpoint // wait for checkHealthEndpoint
if err := p.checkHealthEndpoint(cmdCtx, healthCheckTimeout); err != nil { if err := p.checkHealthEndpoint(cmdCtx); err != nil {
return err return err
} }
if p.config.UnloadAfter > 0 {
// start a goroutine to check every second if
// the process should be stopped
go func() {
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
for {
<-ticker.C
if time.Since(p.lastRequestHandled) > maxDuration {
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter)
p.Stop()
return
}
}
}()
}
return nil return nil
} }
@@ -80,15 +113,20 @@ func (p *Process) Stop() {
p.Lock() p.Lock()
defer p.Unlock() defer p.Unlock()
if p.cmd == nil { if !p.isRunning {
return return
} }
p.cmd.Process.Signal(syscall.SIGTERM) p.cmd.Process.Signal(syscall.SIGTERM)
p.cmd.Process.Wait() p.cmd.Process.Wait()
p.isRunning = false
} }
func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout int) error { func (p *Process) IsRunning() bool {
return p.isRunning
}
func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
if p.config.Proxy == "" { if p.config.Proxy == "" {
return fmt.Errorf("no upstream available to check /health") return fmt.Errorf("no upstream available to check /health")
} }
@@ -105,7 +143,7 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout
} }
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
maxDuration := time.Second * time.Duration(healthCheckTimeout) maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
healthURL, err := url.JoinPath(proxyTo, checkEndpoint) healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
if err != nil { if err != nil {
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint) return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
@@ -115,7 +153,6 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout
startTime := time.Now() startTime := time.Now()
for { for {
time.Sleep(time.Second)
req, err := http.NewRequest("GET", healthURL, nil) req, err := http.NewRequest("GET", healthURL, nil)
if err != nil { if err != nil {
return err return err
@@ -162,15 +199,22 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout
if ttl < 0 { if ttl < 0 {
return fmt.Errorf("failed to check health from: %s", healthURL) return fmt.Errorf("failed to check health from: %s", healthURL)
} }
time.Sleep(time.Second)
} }
} }
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
if p.cmd == nil { if !p.isRunning {
http.Error(w, "process not started", http.StatusInternalServerError) if err := p.start(); err != nil {
return errstr := fmt.Sprintf("unable to start process: %s", err)
http.Error(w, errstr, http.StatusInternalServerError)
return
}
} }
p.lastRequestHandled = time.Now()
proxyTo := p.config.Proxy proxyTo := p.config.Proxy
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)

View File

@@ -2,14 +2,17 @@ package proxy
import ( import (
"fmt" "fmt"
"io"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert"
) )
// Check if the binary exists // Check if the binary exists
@@ -29,53 +32,40 @@ func getBinaryPath() string {
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch)) return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
} }
func TestProcess_ProcessStartStop(t *testing.T) { func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
// Define the range // Define the range
min := 12000 min := 12000
max := 13000 max := 13000
// Generate a random number between 12000 and 13000 // Generate a random number between 12000 and 13000
randomPort := rand.Intn(max-min+1) + min randomPort := rand.Intn(max-min+1) + min
binaryPath := getBinaryPath() binaryPath := getBinaryPath()
// Create a log monitor
logMonitor := NewLogMonitor()
expectedMessage := "testing91931"
// Create a process configuration // Create a process configuration
config := ModelConfig{ return ModelConfig{
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage), Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage),
Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort), Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort),
CheckEndpoint: "/health", CheckEndpoint: "/health",
} }
}
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
logMonitor := NewLogMonitorWriter(io.Discard)
expectedMessage := "testing91931"
config := getTestSimpleResponderConfig(expectedMessage)
// Create a process // Create a process
process := NewProcess("test-process", config, logMonitor) process := NewProcess("test-process", 5, config, logMonitor)
// Start the process
t.Logf("Starting %s on port %d", binaryPath, randomPort)
err := process.Start(5)
if err != nil {
t.Fatalf("Failed to start process: %v", err)
}
// Create a test request
req := httptest.NewRequest("GET", "/", nil) req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
// Proxy the request // process is automatically started
assert.False(t, process.IsRunning())
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
assert.True(t, process.IsRunning())
// Check the response assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
if w.Code != http.StatusOK { assert.Contains(t, w.Body.String(), expectedMessage)
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
}
if !strings.Contains(w.Body.String(), expectedMessage) {
t.Errorf("Expected body to contain '%s', got %q", expectedMessage, w.Body.String())
}
// Stop the process // Stop the process
process.Stop() process.Stop()
@@ -86,8 +76,53 @@ func TestProcess_ProcessStartStop(t *testing.T) {
// Proxy the request // Proxy the request
process.ProxyRequest(w, req) process.ProxyRequest(w, req)
// Check the response // should have automatically started the process again
if w.Code == http.StatusInternalServerError { if w.Code != http.StatusOK {
t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code) t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
} }
} }
// test that the automatic start returns the expected error type
func TestProcess_BrokenModelConfig(t *testing.T) {
// Create a process configuration
config := ModelConfig{
Cmd: "nonexistant-command",
Proxy: "http://127.0.0.1:9913",
CheckEndpoint: "/health",
}
process := NewProcess("broken", 1, config, NewLogMonitor())
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
process.ProxyRequest(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Contains(t, w.Body.String(), "unable to start process")
}
// TestProcess_UnloadAfterTTL verifies that a process configured with a
// 3 second UnloadAfter value is running right after its first proxied
// request and is no longer running once 5 seconds have passed without
// further requests. Skipped under -short because of the sleep.
func TestProcess_UnloadAfterTTL(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping long auto unload TTL test")
	}

	reply := "I_sense_imminent_danger"
	cfg := getTestSimpleResponderConfig(reply)

	// The helper leaves the TTL at its zero value; override it for this test.
	assert.Equal(t, 0, cfg.UnloadAfter)
	cfg.UnloadAfter = 3 // seconds
	assert.Equal(t, 3, cfg.UnloadAfter)

	proc := NewProcess("ttl", 2, cfg, NewLogMonitorWriter(io.Discard))
	rec := httptest.NewRecorder()

	// The first request starts the upstream process on demand.
	proc.ProxyRequest(rec, httptest.NewRequest("GET", "/", nil))
	assert.Equal(t, http.StatusOK, rec.Code, "Expected status code %d, got %d", http.StatusOK, rec.Code)
	assert.Contains(t, rec.Body.String(), reply)
	assert.True(t, proc.IsRunning())

	// Sleep past the TTL so the process has been unloaded.
	time.Sleep(5 * time.Second)
	assert.False(t, proc.IsRunning())
}

View File

@@ -99,8 +99,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) error {
} }
} }
pm.currentProcess = NewProcess(modelID, modelConfig, pm.logMonitor) pm.currentProcess = NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
return pm.currentProcess.Start(pm.config.HealthCheckTimeout) return nil
} }
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) { func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {