Rewrite the swap behaviour so that in-flight requests block process swapping until they are completed. Additionally: - add tests for parallel requests with proxy.ProxyManager and proxy.Process - improve Process startup behaviour and simplified the code - stopping of processes are sent SIGTERM and have 5 seconds to terminate, before they are killed
332 lines
7.3 KiB
Go
332 lines
7.3 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
"syscall"
|
|
"time"
|
|
)
|
|
|
|
type ProcessState string
|
|
|
|
const (
|
|
StateStopped ProcessState = ProcessState("stopped")
|
|
StateReady ProcessState = ProcessState("ready")
|
|
StateFailed ProcessState = ProcessState("failed")
|
|
)
|
|
|
|
type Process struct {
|
|
sync.Mutex
|
|
|
|
ID string
|
|
config ModelConfig
|
|
cmd *exec.Cmd
|
|
logMonitor *LogMonitor
|
|
healthCheckTimeout int
|
|
|
|
lastRequestHandled time.Time
|
|
|
|
stateMutex sync.RWMutex
|
|
state ProcessState
|
|
|
|
inFlightRequests sync.WaitGroup
|
|
}
|
|
|
|
func NewProcess(ID string, healthCheckTimeout int, config ModelConfig, logMonitor *LogMonitor) *Process {
|
|
return &Process{
|
|
ID: ID,
|
|
config: config,
|
|
cmd: nil,
|
|
logMonitor: logMonitor,
|
|
healthCheckTimeout: healthCheckTimeout,
|
|
state: StateStopped,
|
|
}
|
|
}
|
|
|
|
// start the process and returns when it is ready
|
|
func (p *Process) start() error {
|
|
|
|
p.stateMutex.Lock()
|
|
defer p.stateMutex.Unlock()
|
|
|
|
if p.state == StateReady {
|
|
return nil
|
|
}
|
|
|
|
if p.state == StateFailed {
|
|
return fmt.Errorf("process is in a failed state and can not be restarted")
|
|
}
|
|
|
|
args, err := p.config.SanitizedCommand()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to get sanitized command: %v", err)
|
|
}
|
|
|
|
p.cmd = exec.Command(args[0], args[1:]...)
|
|
p.cmd.Stdout = p.logMonitor
|
|
p.cmd.Stderr = p.logMonitor
|
|
p.cmd.Env = p.config.Env
|
|
|
|
err = p.cmd.Start()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// One of three things can happen at this stage:
|
|
// 1. The command exits unexpectedly
|
|
// 2. The health check fails
|
|
// 3. The health check passes
|
|
//
|
|
// only in the third case will the process be considered Ready to accept
|
|
healthCheckContext, cancelHealthCheck := context.WithCancelCause(context.Background())
|
|
defer cancelHealthCheck(nil) // clean up
|
|
cmdWaitChan := make(chan error, 1)
|
|
healthCheckChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
// possible cmd exits early
|
|
cmdWaitChan <- p.cmd.Wait()
|
|
}()
|
|
|
|
go func() {
|
|
<-time.After(250 * time.Millisecond) // give process a bit of time to start
|
|
healthCheckChan <- p.checkHealthEndpoint(healthCheckContext)
|
|
}()
|
|
|
|
select {
|
|
case err := <-cmdWaitChan:
|
|
p.state = StateFailed
|
|
if err != nil {
|
|
err = fmt.Errorf("command [%s] %s", strings.Join(p.cmd.Args, " "), err.Error())
|
|
} else {
|
|
err = fmt.Errorf("command [%s] exited unexpected", strings.Join(p.cmd.Args, " "))
|
|
}
|
|
cancelHealthCheck(err)
|
|
return err
|
|
case err := <-healthCheckChan:
|
|
if err != nil {
|
|
p.state = StateFailed
|
|
return err
|
|
}
|
|
}
|
|
|
|
if p.config.UnloadAfter > 0 {
|
|
// start a goroutine to check every second if
|
|
// the process should be stopped
|
|
go func() {
|
|
ticker := time.NewTicker(time.Second)
|
|
defer ticker.Stop()
|
|
maxDuration := time.Duration(p.config.UnloadAfter) * time.Second
|
|
|
|
for {
|
|
<-ticker.C
|
|
if time.Since(p.lastRequestHandled) > maxDuration {
|
|
fmt.Fprintf(p.logMonitor, "!!! Unloading model %s, TTL of %d reached.\n", p.ID, p.config.UnloadAfter)
|
|
p.Stop()
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
p.state = StateReady
|
|
return nil
|
|
}
|
|
|
|
func (p *Process) Stop() {
|
|
// wait for any inflight requests before proceeding
|
|
p.inFlightRequests.Wait()
|
|
|
|
p.stateMutex.Lock()
|
|
defer p.stateMutex.Unlock()
|
|
|
|
if p.state != StateReady {
|
|
return
|
|
}
|
|
|
|
if p.cmd == nil || p.cmd.Process == nil {
|
|
// this situation should never happen... but if it does just update the state
|
|
fmt.Fprintf(p.logMonitor, "!!! State is Ready but Command is nil.")
|
|
p.state = StateStopped
|
|
return
|
|
}
|
|
|
|
// Pretty sure this stopping code needs some work for windows and
|
|
// will be a source of pain in the future.
|
|
|
|
p.cmd.Process.Signal(syscall.SIGTERM)
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- p.cmd.Wait()
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
fmt.Printf("!!! process for %s timed out waiting to stop\n", p.ID)
|
|
p.cmd.Process.Kill()
|
|
p.cmd.Wait()
|
|
case err := <-done:
|
|
if err != nil {
|
|
if err.Error() != "wait: no child processes" {
|
|
// possible that simple-responder for testing is just not
|
|
// existing right, so suppress those errors.
|
|
fmt.Printf("!!! process for %s stopped with error > %v\n", p.ID, err)
|
|
}
|
|
}
|
|
}
|
|
p.state = StateStopped
|
|
}
|
|
|
|
func (p *Process) CurrentState() ProcessState {
|
|
p.stateMutex.RLock()
|
|
defer p.stateMutex.RUnlock()
|
|
return p.state
|
|
}
|
|
|
|
func (p *Process) checkHealthEndpoint(ctxFromStart context.Context) error {
|
|
if p.config.Proxy == "" {
|
|
return fmt.Errorf("no upstream available to check /health")
|
|
}
|
|
|
|
checkEndpoint := strings.TrimSpace(p.config.CheckEndpoint)
|
|
|
|
if checkEndpoint == "none" {
|
|
return nil
|
|
}
|
|
|
|
// keep default behaviour
|
|
if checkEndpoint == "" {
|
|
checkEndpoint = "/health"
|
|
}
|
|
|
|
proxyTo := p.config.Proxy
|
|
maxDuration := time.Second * time.Duration(p.healthCheckTimeout)
|
|
healthURL, err := url.JoinPath(proxyTo, checkEndpoint)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint)
|
|
}
|
|
|
|
client := &http.Client{}
|
|
startTime := time.Now()
|
|
|
|
for {
|
|
req, err := http.NewRequest("GET", healthURL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(ctxFromStart, time.Second)
|
|
defer cancel()
|
|
req = req.WithContext(ctx)
|
|
resp, err := client.Do(req)
|
|
|
|
ttl := (maxDuration - time.Since(startTime)).Seconds()
|
|
|
|
if err != nil {
|
|
// check if the context was cancelled
|
|
select {
|
|
case <-ctx.Done():
|
|
err := context.Cause(ctx)
|
|
if !errors.Is(err, context.DeadlineExceeded) {
|
|
return err
|
|
}
|
|
default:
|
|
}
|
|
|
|
// wait a bit longer for TCP connection issues
|
|
if strings.Contains(err.Error(), "connection refused") {
|
|
fmt.Fprintf(p.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl)
|
|
time.Sleep(5 * time.Second)
|
|
} else {
|
|
time.Sleep(time.Second)
|
|
}
|
|
|
|
if ttl < 0 {
|
|
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
}
|
|
|
|
continue
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
if resp.StatusCode == http.StatusOK {
|
|
return nil
|
|
}
|
|
|
|
if ttl < 0 {
|
|
return fmt.Errorf("failed to check health from: %s", healthURL)
|
|
}
|
|
|
|
time.Sleep(time.Second)
|
|
}
|
|
}
|
|
|
|
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|
|
|
p.inFlightRequests.Add(1)
|
|
defer p.inFlightRequests.Done()
|
|
|
|
if p.CurrentState() != StateReady {
|
|
if err := p.start(); err != nil {
|
|
errstr := fmt.Sprintf("unable to start process: %s", err)
|
|
http.Error(w, errstr, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
p.lastRequestHandled = time.Now()
|
|
|
|
proxyTo := p.config.Proxy
|
|
client := &http.Client{}
|
|
req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
return
|
|
}
|
|
req.Header = r.Header.Clone()
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
for k, vv := range resp.Header {
|
|
for _, v := range vv {
|
|
w.Header().Add(k, v)
|
|
}
|
|
}
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
// faster than io.Copy when streaming
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
n, err := resp.Body.Read(buf)
|
|
if n > 0 {
|
|
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
|
return
|
|
}
|
|
if flusher, ok := w.(http.Flusher); ok {
|
|
flusher.Flush()
|
|
}
|
|
}
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
if err != nil {
|
|
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
return
|
|
}
|
|
}
|
|
}
|