add proxy.Process to manage upstream proxy logic
This commit is contained in:
@@ -25,22 +25,22 @@ type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
}
|
||||
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, bool) {
|
||||
func (c *Config) FindConfig(modelName string) (ModelConfig, string, bool) {
|
||||
modelConfig, found := c.Models[modelName]
|
||||
if found {
|
||||
return modelConfig, true
|
||||
return modelConfig, modelName, true
|
||||
}
|
||||
|
||||
// Search through aliases to find the right config
|
||||
for _, config := range c.Models {
|
||||
for actual, config := range c.Models {
|
||||
for _, alias := range config.Aliases {
|
||||
if alias == modelName {
|
||||
return config, true
|
||||
return config, actual, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ModelConfig{}, false
|
||||
return ModelConfig{}, "", false
|
||||
}
|
||||
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
|
||||
@@ -91,18 +91,21 @@ func TestFindConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test finding a model by its name
|
||||
modelConfig, found := config.FindConfig("model1")
|
||||
modelConfig, modelId, found := config.FindConfig("model1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model by its alias
|
||||
modelConfig, found = config.FindConfig("m1")
|
||||
modelConfig, modelId, found = config.FindConfig("m1")
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "model1", modelId)
|
||||
assert.Equal(t, config.Models["model1"], modelConfig)
|
||||
|
||||
// Test finding a model that does not exist
|
||||
modelConfig, found = config.FindConfig("model3")
|
||||
modelConfig, modelId, found = config.FindConfig("model3")
|
||||
assert.False(t, found)
|
||||
assert.Equal(t, "", modelId)
|
||||
assert.Equal(t, ModelConfig{}, modelConfig)
|
||||
}
|
||||
|
||||
|
||||
214
proxy/manager.go
214
proxy/manager.go
@@ -2,31 +2,24 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ProxyManager struct {
|
||||
sync.Mutex
|
||||
|
||||
config *Config
|
||||
currentCmd *exec.Cmd
|
||||
currentConfig ModelConfig
|
||||
logMonitor *LogMonitor
|
||||
config *Config
|
||||
currentProcess *Process
|
||||
logMonitor *LogMonitor
|
||||
}
|
||||
|
||||
func New(config *Config) *ProxyManager {
|
||||
return &ProxyManager{config: config, logMonitor: NewLogMonitor()}
|
||||
return &ProxyManager{config: config, currentProcess: nil, logMonitor: NewLogMonitor()}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -40,7 +33,12 @@ func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) {
|
||||
} else if r.URL.Path == "/logs" {
|
||||
pm.streamLogs(w, r)
|
||||
} else {
|
||||
pm.proxyRequest(w, r)
|
||||
if pm.currentProcess != nil {
|
||||
pm.currentProcess.ProxyRequest(w, r)
|
||||
} else {
|
||||
http.Error(w, "no strategy to handle request", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,144 +109,22 @@ func (pm *ProxyManager) swapModel(requestedModel string) error {
|
||||
defer pm.Unlock()
|
||||
|
||||
// find the model configuration matching requestedModel
|
||||
modelConfig, found := pm.config.FindConfig(requestedModel)
|
||||
modelConfig, modelID, found := pm.config.FindConfig(requestedModel)
|
||||
if !found {
|
||||
return fmt.Errorf("could not find configuration for %s", requestedModel)
|
||||
}
|
||||
|
||||
// no need to swap llama.cpp instances
|
||||
if pm.currentConfig.Cmd == modelConfig.Cmd {
|
||||
return nil
|
||||
}
|
||||
|
||||
// kill the current running one to swap it
|
||||
if pm.currentCmd != nil {
|
||||
pm.currentCmd.Process.Signal(syscall.SIGTERM)
|
||||
|
||||
// wait for it to end
|
||||
pm.currentCmd.Process.Wait()
|
||||
}
|
||||
|
||||
pm.currentConfig = modelConfig
|
||||
|
||||
args, err := modelConfig.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
cmd := exec.Command(args[0], args[1:]...)
|
||||
|
||||
// logMonitor only writes to stdout
|
||||
// so the upstream's stderr will go to os.Stdout
|
||||
cmd.Stdout = pm.logMonitor
|
||||
cmd.Stderr = pm.logMonitor
|
||||
|
||||
cmd.Env = modelConfig.Env
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pm.currentCmd = cmd
|
||||
|
||||
// watch for the command to exist
|
||||
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
||||
|
||||
// monitor the command's exist status
|
||||
go func() {
|
||||
err := cmd.Wait()
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("command [%s] %s", strings.Join(cmd.Args, " "), err.Error()))
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for checkHealthEndpoint
|
||||
if err := pm.checkHealthEndpoint(cmdCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) checkHealthEndpoint(cmdCtx context.Context) error {
|
||||
|
||||
if pm.currentConfig.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := pm.currentConfig.Proxy
|
||||
maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(cmdCtx, time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := context.Cause(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(pm.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
// do nothing as it's already the correct process
|
||||
if pm.currentProcess != nil {
|
||||
if pm.currentProcess.ID == modelID {
|
||||
return nil
|
||||
} else {
|
||||
pm.currentProcess.Stop()
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
pm.currentProcess = NewProcess(modelID, modelConfig, pm.logMonitor)
|
||||
return pm.currentProcess.Start(pm.config.HealthCheckTimeout)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -274,55 +150,5 @@ func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
pm.proxyRequest(w, r)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if pm.currentConfig.Proxy == "" {
|
||||
http.Error(w, "No upstream proxy", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxyTo := pm.currentConfig.Proxy
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// faster than io.Copy when streaming
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
pm.currentProcess.ProxyRequest(w, r)
|
||||
}
|
||||
|
||||
218
proxy/process.go
Normal file
218
proxy/process.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Process struct {
|
||||
sync.Mutex
|
||||
|
||||
ID string
|
||||
config ModelConfig
|
||||
cmd *exec.Cmd
|
||||
logMonitor *LogMonitor
|
||||
}
|
||||
|
||||
func NewProcess(ID string, config ModelConfig, logMonitor *LogMonitor) *Process {
|
||||
return &Process{
|
||||
ID: ID,
|
||||
config: config,
|
||||
cmd: nil,
|
||||
logMonitor: logMonitor,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Process) Start(healthCheckTimeout int) error {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if p.cmd != nil {
|
||||
return fmt.Errorf("process already started")
|
||||
}
|
||||
|
||||
args, err := p.config.SanitizedCommand()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get sanitized command: %v", err)
|
||||
}
|
||||
|
||||
p.cmd = exec.Command(args[0], args[1:]...)
|
||||
p.cmd.Stdout = p.logMonitor
|
||||
p.cmd.Stderr = p.logMonitor
|
||||
p.cmd.Env = p.config.Env
|
||||
|
||||
err = p.cmd.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// watch for the command to exit
|
||||
cmdCtx, cancel := context.WithCancelCause(context.Background())
|
||||
|
||||
// monitor the command's exit status
|
||||
go func() {
|
||||
err := p.cmd.Wait()
|
||||
if err != nil {
|
||||
cancel(fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error()))
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
}()
|
||||
|
||||
// wait for checkHealthEndpoint
|
||||
if err := p.checkHealthEndpoint(cmdCtx, healthCheckTimeout); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Process) Stop() {
|
||||
p.Lock()
|
||||
defer p.Unlock()
|
||||
|
||||
if p.cmd == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.cmd.Process.Signal(syscall.SIGTERM)
|
||||
p.cmd.Process.Wait()
|
||||
}
|
||||
|
||||
func (p *Process) checkHealthEndpoint(cmdCtx context.Context, healthCheckTimeout int) error {
|
||||
if p.config.Proxy == "" {
|
||||
return fmt.Errorf("no upstream available to check /health")
|
||||
}
|
||||
|
||||
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
||||
|
||||
if checkEndpoint == "none" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keep default behaviour
|
||||
if checkEndpoint == "" {
|
||||
checkEndpoint = "/health"
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
maxDuration := time.Second * time.Duration(healthCheckTimeout)
|
||||
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
startTime := time.Now()
|
||||
|
||||
for {
|
||||
req, err := http.NewRequest("GET", healthURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(cmdCtx, time.Second)
|
||||
defer cancel()
|
||||
req = req.WithContext(ctx)
|
||||
resp, err := client.Do(req)
|
||||
|
||||
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
||||
|
||||
if err != nil {
|
||||
// check if the context was cancelled
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
err := context.Cause(ctx)
|
||||
if !errors.Is(err, context.DeadlineExceeded) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
// wait a bit longer for TCP connection issues
|
||||
if strings.Contains(err.Error(), "connection refused") {
|
||||
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
||||
|
||||
time.Sleep(5 * time.Second)
|
||||
} else {
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
return nil
|
||||
}
|
||||
|
||||
if ttl < 0 {
|
||||
return fmt.Errorf("failed to check health from: %s", healthURL)
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
// sends the request to the upstream process
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
if p.cmd == nil {
|
||||
http.Error(w, "process not started", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
proxyTo := p.config.Proxy
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
req.Header = r.Header
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
for k, vv := range resp.Header {
|
||||
for _, v := range vv {
|
||||
w.Header().Add(k, v)
|
||||
}
|
||||
}
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
// faster than io.Copy when streaming
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := resp.Body.Read(buf)
|
||||
if n > 0 {
|
||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
||||
return
|
||||
}
|
||||
if flusher, ok := w.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user