* Make starting the upstream process on-demand (#10) * Add automatic unloading of a model after its TTL is reached * Add a `ttl` configuration parameter to models, in seconds; the default is 0 (never unload)
This commit is contained in:
3
Makefile
3
Makefile
@@ -10,6 +10,9 @@ clean:
|
|||||||
rm -rf $(BUILD_DIR)
|
rm -rf $(BUILD_DIR)
|
||||||
|
|
||||||
test:
|
test:
|
||||||
|
go test -short -v ./proxy
|
||||||
|
|
||||||
|
test-all:
|
||||||
go test -v ./proxy
|
go test -v ./proxy
|
||||||
|
|
||||||
# Build OSX binary
|
# Build OSX binary
|
||||||
|
|||||||
@@ -39,6 +39,11 @@ models:
|
|||||||
# until the server is ready
|
# until the server is ready
|
||||||
checkEndpoint: /custom-endpoint
|
checkEndpoint: /custom-endpoint
|
||||||
|
|
||||||
|
# automatically unload the model after 5 seconds without a request
|
||||||
|
# ttl is in seconds and must be greater than 0 to enable automatic unloading
|
||||||
|
# default: 0 = never unload model
|
||||||
|
ttl: 5
|
||||||
|
|
||||||
"qwen":
|
"qwen":
|
||||||
# environment variables to pass to the command
|
# environment variables to pass to the command
|
||||||
env:
|
env:
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ models:
|
|||||||
# check this path for a HTTP 200 response for the server to be ready
|
# check this path for a HTTP 200 response for the server to be ready
|
||||||
checkEndpoint: /health
|
checkEndpoint: /health
|
||||||
|
|
||||||
|
# unload model after 5 seconds
|
||||||
|
ttl: 5
|
||||||
|
|
||||||
"qwen":
|
"qwen":
|
||||||
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
cmd: models/llama-server-osx --port 8999 -m models/qwen2.5-0.5b-instruct-q8_0.gguf
|
||||||
proxy: http://127.0.0.1:8999
|
proxy: http://127.0.0.1:8999
|
||||||
|
|||||||
@@ -20,7 +20,11 @@ func main() {
|
|||||||
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
// Set the header to text/plain
|
// Set the header to text/plain
|
||||||
w.Header().Set("Content-Type", "text/plain")
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
|
fmt.Fprintln(w, *responseMessage)
|
||||||
|
})
|
||||||
|
|
||||||
|
http.HandleFunc("/env", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain")
|
||||||
fmt.Fprintln(w, *responseMessage)
|
fmt.Fprintln(w, *responseMessage)
|
||||||
|
|
||||||
// Get environment variables
|
// Get environment variables
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ type ModelConfig struct {
|
|||||||
Aliases []string `yaml:"aliases"`
|
Aliases []string `yaml:"aliases"`
|
||||||
Env []string `yaml:"env"`
|
Env []string `yaml:"env"`
|
||||||
CheckEndpoint string `yaml:"checkEndpoint"`
|
CheckEndpoint string `yaml:"checkEndpoint"`
|
||||||
|
UnloadAfter int `yaml:"ttl"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
func (m *ModelConfig) SanitizedCommand() ([]string, error) {
|
||||||
|
|||||||
@@ -17,27 +17,33 @@ import (
|
|||||||
type Process struct {
|
type Process struct {
|
||||||
sync.Mutex
|
sync.Mutex
|
||||||
|
|
||||||
ID string
|
ID string
|
||||||
config ModelConfig
|
config ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
logMonitor *LogMonitor
|
logMonitor *LogMonitor
|
||||||
|
healthCheckTimeout int
|
||||||
|
|
||||||
|
isRunning bool
|
||||||
|
lastRequestHandled time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewProcess(ID string, config ModelConfig, logMonitor *LogMonitor) *Process {
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
logMonitor: logMonitor,
|
logMonitor: logMonitor,
|
||||||
|
healthCheckTimeout: healthCheckTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) Start(healthCheckTimeout int) error {
|
// start the process and check it for errors
|
||||||
|
func (p *Process) start() error {
|
||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
if p.cmd != nil {
|
if p.isRunning {
|
||||||
return fmt.Errorf("process already started")
|
return fmt.Errorf("process already running")
|
||||||
}
|
}
|
||||||
|
|
||||||
args, err := p.config.SanitizedCommand()
|
args, err := p.config.SanitizedCommand()
|
||||||
@@ -51,6 +57,8 @@ func (p *Process) Start(healthCheckTimeout int) error {
|
|||||||
p.cmd.Env = p.config.Env
|
p.cmd.Env = p.config.Env
|
||||||
|
|
||||||
err = p.cmd.Start()
|
err = p.cmd.Start()
|
||||||
|
p.isRunning = true
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -58,7 +66,8 @@ func (p *Process) Start(healthCheckTimeout int) error {
|
|||||||
// watch for the command to exit
|
// watch for the command to exit
|
||||||
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
||||||
|
|
||||||
// monitor the command's exit status
|
// monitor the command's exit status. Usually this happens if
|
||||||
|
// the process exited unexpectedly
|
||||||
go func() {
|
go func() {
|
||||||
err := p.cmd.Wait()
|
err := p.cmd.Wait()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -66,13 +75,37 @@ func (p *Process) Start(healthCheckTimeout int) error {
|
|||||||
} else {
|
} else {
|
||||||
cancel(nil)
|
cancel(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.isRunning = false
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// wait a bit for process to start before checking the health endpoint
|
||||||
|
time.Sleep(250 * time.Millisecond)
|
||||||
|
|
||||||
// wait for checkHealthEndpoint
|
// wait for checkHealthEndpoint
|
||||||
if err := p.checkHealthEndpoint(cmdCtx, healthCheckTimeout); err != nil {
|
if err := p.checkHealthEndpoint(cmdCtx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.config.UnloadAfter > 0 {
|
||||||
|
// start a goroutine to check every second if
|
||||||
|
// the process should be stopped
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
||||||
|
|
||||||
|
for {
|
||||||
|
<-ticker.C
|
||||||
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
||||||
|
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter)
|
||||||
|
p.Stop()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -80,15 +113,20 @@ func (p *Process) Stop() {
|
|||||||
p.Lock()
|
p.Lock()
|
||||||
defer p.Unlock()
|
defer p.Unlock()
|
||||||
|
|
||||||
if p.cmd == nil {
|
if !p.isRunning {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||||
p.cmd.Process.Wait()
|
p.cmd.Process.Wait()
|
||||||
|
p.isRunning = false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout int) error {
|
func (p *Process) IsRunning() bool {
|
||||||
|
return p.isRunning
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Process) checkHealthEndpoint(cmdCtx context.Context) error {
|
||||||
if p.config.Proxy == "" {
|
if p.config.Proxy == "" {
|
||||||
return fmt.Errorf("no upstream available to check /health")
|
return fmt.Errorf("no upstream available to check /health")
|
||||||
}
|
}
|
||||||
@@ -105,7 +143,7 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout
|
|||||||
}
|
}
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
maxDuration := time.Second * time.Duration(healthCheckTimeout)
|
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
||||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||||
@@ -115,7 +153,6 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
time.Sleep(time.Second)
|
|
||||||
req, err := http.NewRequest("GET", healthURL, nil)
|
req, err := http.NewRequest("GET", healthURL, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -162,15 +199,22 @@ func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout
|
|||||||
if ttl < 0 {
|
if ttl < 0 {
|
||||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
if p.cmd == nil {
|
if !p.isRunning {
|
||||||
http.Error(w, "process not started", http.StatusInternalServerError)
|
if err := p.start(); err != nil {
|
||||||
return
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||||
|
http.Error(w, errstr, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
p.lastRequestHandled = time.Now()
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
proxyTo := p.config.Proxy
|
||||||
client := &http.Client{}
|
client := &http.Client{}
|
||||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||||
|
|||||||
@@ -2,14 +2,17 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Check if the binary exists
|
// Check if the binary exists
|
||||||
@@ -29,53 +32,40 @@ func getBinaryPath() string {
|
|||||||
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
return filepath.Join("..", "build", fmt.Sprintf("simple-responder_%s_%s", goos, goarch))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProcess_ProcessStartStop(t *testing.T) {
|
func getTestSimpleResponderConfig(expectedMessage string) ModelConfig {
|
||||||
// Define the range
|
// Define the range
|
||||||
min := 12000
|
min := 12000
|
||||||
max := 13000
|
max := 13000
|
||||||
|
|
||||||
// Generate a random number between 12000 and 13000
|
// Generate a random number between 12000 and 13000
|
||||||
randomPort := rand.Intn(max-min+1) + min
|
randomPort := rand.Intn(max-min+1) + min
|
||||||
|
|
||||||
binaryPath := getBinaryPath()
|
binaryPath := getBinaryPath()
|
||||||
|
|
||||||
// Create a log monitor
|
|
||||||
logMonitor := NewLogMonitor()
|
|
||||||
|
|
||||||
expectedMessage := "testing91931"
|
|
||||||
|
|
||||||
// Create a process configuration
|
// Create a process configuration
|
||||||
config := ModelConfig{
|
return ModelConfig{
|
||||||
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage),
|
Cmd: fmt.Sprintf("%s --port %d --respond '%s'", binaryPath, randomPort, expectedMessage),
|
||||||
Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort),
|
Proxy: fmt.Sprintf("http://127.0.0.1:%d", randomPort),
|
||||||
CheckEndpoint: "/health",
|
CheckEndpoint: "/health",
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_AutomaticallyStartsUpstream(t *testing.T) {
|
||||||
|
logMonitor := NewLogMonitorWriter(io.Discard)
|
||||||
|
expectedMessage := "testing91931"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
|
||||||
// Create a process
|
// Create a process
|
||||||
process := NewProcess("test-process", config, logMonitor)
|
process := NewProcess("test-process", 5, config, logMonitor)
|
||||||
|
|
||||||
// Start the process
|
|
||||||
t.Logf("Starting %s on port %d", binaryPath, randomPort)
|
|
||||||
err := process.Start(5)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to start process: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a test request
|
|
||||||
req := httptest.NewRequest("GET", "/", nil)
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
// Proxy the request
|
// process is automatically started
|
||||||
|
assert.False(t, process.IsRunning())
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
|
assert.True(t, process.IsRunning())
|
||||||
|
|
||||||
// Check the response
|
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
if w.Code != http.StatusOK {
|
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||||
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !strings.Contains(w.Body.String(), expectedMessage) {
|
|
||||||
t.Errorf("Expected body to contain '%s', got %q", expectedMessage, w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop the process
|
// Stop the process
|
||||||
process.Stop()
|
process.Stop()
|
||||||
@@ -86,8 +76,53 @@ func TestProcess_ProcessStartStop(t *testing.T) {
|
|||||||
// Proxy the request
|
// Proxy the request
|
||||||
process.ProxyRequest(w, req)
|
process.ProxyRequest(w, req)
|
||||||
|
|
||||||
// Check the response
|
// should have automatically started the process again
|
||||||
if w.Code == http.StatusInternalServerError {
|
if w.Code != http.StatusOK {
|
||||||
t.Errorf("Expected status code %d, got %d", http.StatusInternalServerError, w.Code)
|
t.Errorf("Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// test that a failed automatic start is reported as an HTTP 500 with an "unable to start process" message
|
||||||
|
func TestProcess_BrokenModelConfig(t *testing.T) {
|
||||||
|
// Create a process configuration
|
||||||
|
config := ModelConfig{
|
||||||
|
Cmd: "nonexistant-command",
|
||||||
|
Proxy: "http://127.0.0.1:9913",
|
||||||
|
CheckEndpoint: "/health",
|
||||||
|
}
|
||||||
|
|
||||||
|
process := NewProcess("broken", 1, config, NewLogMonitor())
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusInternalServerError, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "unable to start process")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProcess_UnloadAfterTTL(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping long auto unload TTL test")
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedMessage := "I_sense_imminent_danger"
|
||||||
|
config := getTestSimpleResponderConfig(expectedMessage)
|
||||||
|
assert.Equal(t, 0, config.UnloadAfter)
|
||||||
|
config.UnloadAfter = 3 // seconds
|
||||||
|
assert.Equal(t, 3, config.UnloadAfter)
|
||||||
|
|
||||||
|
process := NewProcess("ttl", 2, config, NewLogMonitorWriter(io.Discard))
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Proxy the request (auto start)
|
||||||
|
process.ProxyRequest(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code, "Expected status code %d, got %d", http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), expectedMessage)
|
||||||
|
|
||||||
|
assert.True(t, process.IsRunning())
|
||||||
|
|
||||||
|
// wait 5 seconds
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
|
||||||
|
assert.False(t, process.IsRunning())
|
||||||
|
}
|
||||||
|
|||||||
@@ -99,8 +99,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pm.currentProcess = NewProcess(modelID, modelConfig, pm.logMonitor)
|
pm.currentProcess = NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
|
||||||
return pm.currentProcess.Start(pm.config.HealthCheckTimeout)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
|
||||||
|
|||||||
Reference in New Issue
Block a user