diff --git a/proxy/config.go b/proxy/config.go index 27e3773..37aac8f 100644 --- a/proxy/config.go +++ b/proxy/config.go @@ -25,22 +25,22 @@ type Config struct { HealthCheckTimeout int `yaml:"healthCheckTimeout"` } -func (c *Config) FindConfig(modelName string) (ModelConfig, bool) { +func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) { modelConfig, found := c.Models[modelName] if found { - return modelConfig, true + return modelConfig, modelName, true } // Search through aliases to find the right config - for _, config := range c.Models { + for actual, config := range c.Models { for _, alias := range config.Aliases { if alias == modelName { - return config, true + return config, actual, true } } } - return ModelConfig{}, false + return ModelConfig{}, "", false } func LoadConfig(path string) (*Config, error) { diff --git a/proxy/config_test.go b/proxy/config_test.go index a6e95e7..b45df37 100644 --- a/proxy/config_test.go +++ b/proxy/config_test.go @@ -91,18 +91,21 @@ func TestFindConfig(t *testing.T) { } // Test finding a model by its name - modelConfig, found := config.FindConfig("model1") + modelConfig, modelId, found := config.FindConfig("model1") assert.True(t, found) + assert.Equal(t, "model1", modelId) assert.Equal(t, config.Models["model1"], modelConfig) // Test finding a model by its alias - modelConfig, found = config.FindConfig("m1") + modelConfig, modelId, found = config.FindConfig("m1") assert.True(t, found) + assert.Equal(t, "model1", modelId) assert.Equal(t, config.Models["model1"], modelConfig) // Test finding a model that does not exist - modelConfig, found = config.FindConfig("model3") + modelConfig, modelId, found = config.FindConfig("model3") assert.False(t, found) + assert.Equal(t, "", modelId) assert.Equal(t, ModelConfig{}, modelConfig) } diff --git a/proxy/manager.go b/proxy/manager.go index 5330b14..f177ce8 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -2,31 +2,24 @@ package proxy import ( "bytes" - "context" "encoding/json" - "errors" "fmt" "io" "net/http" - "net/url" - "os/exec" - "strings" "sync" - "syscall" "time" ) type ProxyManager struct { sync.Mutex - config *Config - currentCmd *exec.Cmd - currentConfig ModelConfig - logMonitor *LogMonitor + config *Config + currentProcess *Process + logMonitor *LogMonitor } func New(config *Config) *ProxyManager { - return &ProxyManager{config: config, logMonitor: NewLogMonitor()} + return &ProxyManager{config: config, currentProcess: nil, logMonitor: NewLogMonitor()} } func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) { @@ -40,7 +33,12 @@ func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) { } else if r.URL.Path == "/logs" { pm.streamLogs(w, r) } else { - pm.proxyRequest(w, r) + if pm.currentProcess != nil { + pm.currentProcess.ProxyRequest(w, r) + } else { + http.Error(w, "no strategy to handle request", http.StatusBadRequest) + } + } } @@ -111,144 +109,22 @@ func (pm *ProxyManager) swapModel(requestedModel string) error { defer pm.Unlock() // find the model configuration matching requestedModel - modelConfig, found := pm.config.FindConfig(requestedModel) + modelConfig, modelID, found := pm.config.FindConfig(requestedModel) if !found { return fmt.Errorf("could not find configuration for %s", requestedModel) } - // no need to swap llama.cpp instances - if pm.currentConfig.Cmd == modelConfig.Cmd { - return nil - } - - // kill the current running one to swap it - if pm.currentCmd != nil { - pm.currentCmd.Process.Signal(syscall.SIGTERM) - - // wait for it to end - pm.currentCmd.Process.Wait() - } - - pm.currentConfig = modelConfig - - args, err := modelConfig.SanitizedCommand() - if err != nil { - return fmt.Errorf("unable to get sanitized command: %v", err) - } - cmd := exec.Command(args[0], args[1:]...) - - // logMonitor only writes to stdout - // so the upstream's stderr will go to os.Stdout - cmd.Stdout = pm.logMonitor - cmd.Stderr = pm.logMonitor - - cmd.Env = modelConfig.Env - - err = cmd.Start() - if err != nil { - return err - } - pm.currentCmd = cmd - - // watch for the command to exist - cmdCtx, cancel := context.WithCancelCause(context.Background()) - - // monitor the command's exist status - go func() { - err := cmd.Wait() - if err != nil { - cancel(fmt.Errorf("command [%s] %s", strings.Join(cmd.Args, " "), err.Error())) - } else { - cancel(nil) - } - }() - - // wait for checkHealthEndpoint - if err := pm.checkHealthEndpoint(cmdCtx); err != nil { - return err - } - - return nil -} - -func (pm *ProxyManager) checkHealthEndpoint(cmdCtx context.Context) error { - - if pm.currentConfig.Proxy == "" { - return fmt.Errorf("no upstream available to check /health") - } - - checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint) - - if checkEndpoint == "none" { - return nil - } - - // keep default behaviour - if checkEndpoint == "" { - checkEndpoint = "/health" - } - - proxyTo := pm.currentConfig.Proxy - maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout) - healthURL, err := url.JoinPath(proxyTo, checkEndpoint) - if err != nil { - return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint) - } - - client := &http.Client{} - startTime := time.Now() - - for { - req, err := http.NewRequest("GET", healthURL, nil) - if err != nil { - return err - } - - ctx, cancel := context.WithTimeout(cmdCtx, time.Second) - defer cancel() - req = req.WithContext(ctx) - resp, err := client.Do(req) - - ttl := (maxDuration - time.Since(startTime)).Seconds() - - if err != nil { - // check if the context was cancelled - select { - case <-ctx.Done(): - err := context.Cause(ctx) - if !errors.Is(err, context.DeadlineExceeded) { - return err - } - default: - } - - // wait a bit longer for TCP connection issues - if strings.Contains(err.Error(), "connection refused") { - fmt.Fprintf(pm.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl) - - time.Sleep(5 * time.Second) - } else { - time.Sleep(time.Second) - } - - if ttl < 0 { - return fmt.Errorf("failed to check health from: %s", healthURL) - } - - continue - } - - defer resp.Body.Close() - if resp.StatusCode == http.StatusOK { + // do nothing as it's already the correct process + if pm.currentProcess != nil { + if pm.currentProcess.ID == modelID { return nil + } else { + pm.currentProcess.Stop() } - - if ttl < 0 { - return fmt.Errorf("failed to check health from: %s", healthURL) - } - - time.Sleep(time.Second) } + + pm.currentProcess = NewProcess(modelID, modelConfig, pm.logMonitor) + return pm.currentProcess.Start(pm.config.HealthCheckTimeout) } func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) { @@ -274,55 +150,5 @@ func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) } r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - pm.proxyRequest(w, r) -} - -func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) { - if pm.currentConfig.Proxy == "" { - http.Error(w, "No upstream proxy", http.StatusInternalServerError) - return - } - - proxyTo := pm.currentConfig.Proxy - - client := &http.Client{} - req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - req.Header = r.Header - resp, err := client.Do(req) - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - defer resp.Body.Close() - for k, vv := range resp.Header { - for _, v := range vv { - w.Header().Add(k, v) - } - } - w.WriteHeader(resp.StatusCode) - - // faster than io.Copy when streaming - buf := make([]byte, 32*1024) - for { - n, err := resp.Body.Read(buf) - if n > 0 { - if _, writeErr := w.Write(buf[:n]); writeErr != nil { - return - } - if flusher, ok := w.(http.Flusher); ok { - flusher.Flush() - } - } - if err == io.EOF { - break - } - if err != nil { - http.Error(w, err.Error(), http.StatusBadGateway) - return - } - } + pm.currentProcess.ProxyRequest(w, r) } diff --git a/proxy/process.go b/proxy/process.go new file mode 100644 index 0000000..0c85864 --- /dev/null +++ b/proxy/process.go @@ -0,0 +1,218 @@ +package proxy + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os/exec" + "strings" + "sync" + "syscall" + "time" +) + +type Process struct { + sync.Mutex + + ID string + config ModelConfig + cmd *exec.Cmd + logMonitor *LogMonitor +} + +func NewProcess(ID string, config ModelConfig, logMonitor *LogMonitor) *Process { + return &Process{ + ID: ID, + config: config, + cmd: nil, + logMonitor: logMonitor, + } +} + +func (p *Process) Start(healthCheckTimeout int) error { + p.Lock() + defer p.Unlock() + + if p.cmd != nil { + return fmt.Errorf("process already started") + } + + args, err := p.config.SanitizedCommand() + if err != nil { + return fmt.Errorf("unable to get sanitized command: %v", err) + } + + p.cmd = exec.Command(args[0], args[1:]...) + p.cmd.Stdout = p.logMonitor + p.cmd.Stderr = p.logMonitor + p.cmd.Env = p.config.Env + + err = p.cmd.Start() + if err != nil { + return err + } + + // watch for the command to exit + cmdCtx, cancel := context.WithCancelCause(context.Background()) + + // monitor the command's exit status + go func() { + err := p.cmd.Wait() + if err != nil { + cancel(fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())) + } else { + cancel(nil) + } + }() + + // wait for checkHealthEndpoint + if err := p.checkHealthEndpoint(cmdCtx, healthCheckTimeout); err != nil { + return err + } + + return nil +} + +func (p *Process) Stop() { + p.Lock() + defer p.Unlock() + + if p.cmd == nil { + return + } + + p.cmd.Process.Signal(syscall.SIGTERM) + p.cmd.Process.Wait() +} + +func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout int) error { + if p.config.Proxy == "" { + return fmt.Errorf("no upstream available to check /health") + } + + checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint) + + if checkEndpoint == "none" { + return nil + } + + // keep default behaviour + if checkEndpoint == "" { + checkEndpoint = "/health" + } + + proxyTo := p.config.Proxy + maxDuration := time.Second * time.Duration(healthCheckTimeout) + healthURL, err := url.JoinPath(proxyTo, checkEndpoint) + if err != nil { + return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint) + } + + client := &http.Client{} + startTime := time.Now() + + for { + req, err := http.NewRequest("GET", healthURL, nil) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(cmdCtx, time.Second) + defer cancel() + req = req.WithContext(ctx) + resp, err := client.Do(req) + + ttl := (maxDuration - time.Since(startTime)).Seconds() + + if err != nil { + // check if the context was cancelled + select { + case <-ctx.Done(): + err := context.Cause(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + return err + } + default: + } + + // wait a bit longer for TCP connection issues + if strings.Contains(err.Error(), "connection refused") { + fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl) + + time.Sleep(5 * time.Second) + } else { + time.Sleep(time.Second) + } + + if ttl < 0 { + return fmt.Errorf("failed to check health from: %s", healthURL) + } + + continue + } + + defer resp.Body.Close() + if resp.StatusCode == http.StatusOK { + return nil + } + + if ttl < 0 { + return fmt.Errorf("failed to check health from: %s", healthURL) + } + + time.Sleep(time.Second) + } +} + +// sends the request to the upstream process +func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) { + if p.cmd == nil { + http.Error(w, "process not started", http.StatusInternalServerError) + return + } + + proxyTo := p.config.Proxy + client := &http.Client{} + req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + req.Header = r.Header + resp, err := client.Do(req) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + for k, vv := range resp.Header { + for _, v := range vv { + w.Header().Add(k, v) + } + } + w.WriteHeader(resp.StatusCode) + + // faster than io.Copy when streaming + buf := make([]byte, 32*1024) + for { + n, err := resp.Body.Read(buf) + if n > 0 { + if _, writeErr := w.Write(buf[:n]); writeErr != nil { + return + } + if flusher, ok := w.(http.Flusher); ok { + flusher.Flush() + } + } + if err == io.EOF { + break + } + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + } +}