Files
llama-swap/proxy/proxymanager.go
Benson Wong 2dc0ca0663 improve llama-swap upstream process recovery and restarts (#155)
Refactor internal upstream process life cycle management to recover better from unexpected situations. With this change llama-swap should never need to be restarted due to a crashed upstream child process.  The `StateFailed` state was removed in favour of always trying to start/restart a process.
2025-06-05 16:24:55 -07:00

567 lines
16 KiB
Go

package proxy
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// PROFILE_SPLIT_CHAR is the ":" separator string.
// NOTE(review): presumably the delimiter between a profile name and a model
// name — confirm against callers; nothing in the visible part of this file
// references it.
const (
PROFILE_SPLIT_CHAR = ":"
)
// ProxyManager bundles the proxy configuration, the gin engine that serves
// HTTP requests, the log monitors, and the process groups keyed by group ID.
// The embedded sync.Mutex is taken by StopProcesses and Shutdown.
type ProxyManager struct {
sync.Mutex
// config is the configuration passed to New.
config Config
// ginEngine holds the middleware and routes installed by setupGinEngine.
ginEngine *gin.Engine
// logging
// proxyLogger receives the proxy's own messages, including the per-request log line.
proxyLogger *LogMonitor
// upstreamLogger is passed to NewProcessGroup alongside proxyLogger.
upstreamLogger *LogMonitor
// muxLogger is the os.Stdout-backed monitor the other two loggers are built on.
muxLogger *LogMonitor
// processGroups maps each group ID found in config.Groups to its ProcessGroup.
processGroups map[string]*ProcessGroup
}
// New builds a ProxyManager from config. It wires up the loggers, applies the
// configured log level to the proxy and upstream loggers, creates one
// ProcessGroup per entry in config.Groups, and installs the HTTP routes.
func New(config Config) *ProxyManager {
	// Both loggers are constructed on top of one stdout-backed monitor.
	stdoutLogger := NewLogMonitorWriter(os.Stdout)
	upstreamLogger := NewLogMonitorWriter(stdoutLogger)
	proxyLogger := NewLogMonitorWriter(stdoutLogger)

	if config.LogRequests {
		proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
	}

	// The level name is matched case-insensitively with surrounding
	// whitespace ignored; "info" and anything unrecognised use LevelInfo.
	switch levelName := strings.ToLower(strings.TrimSpace(config.LogLevel)); levelName {
	case "debug":
		proxyLogger.SetLogLevel(LevelDebug)
		upstreamLogger.SetLogLevel(LevelDebug)
	case "warn":
		proxyLogger.SetLogLevel(LevelWarn)
		upstreamLogger.SetLogLevel(LevelWarn)
	case "error":
		proxyLogger.SetLogLevel(LevelError)
		upstreamLogger.SetLogLevel(LevelError)
	default:
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	}

	manager := &ProxyManager{
		config:         config,
		ginEngine:      gin.New(),
		proxyLogger:    proxyLogger,
		muxLogger:      stdoutLogger,
		upstreamLogger: upstreamLogger,
		processGroups:  make(map[string]*ProcessGroup),
	}

	// One process group per configured group ID.
	for id := range config.Groups {
		manager.processGroups[id] = NewProcessGroup(id, config, proxyLogger, upstreamLogger)
	}

	manager.setupGinEngine()
	return manager
}
// setupGinEngine installs the middleware and routes on pm.ginEngine:
// an access-log middleware, a permissive CORS preflight responder, the
// OpenAI-style API endpoints, the log endpoints, the upstream passthrough,
// and the embedded index page and favicon.
func (pm *ProxyManager) setupGinEngine() {
	engine := pm.ginEngine

	// accessLog writes one Info line per request once the handlers have run.
	accessLog := func(c *gin.Context) {
		startedAt := time.Now()

		// Snapshot these before c.Next(): proxyToUpstream rewrites
		// c.Request.URL.Path while handling /upstream/:model_id.
		remoteIP := c.ClientIP()
		verb := c.Request.Method
		urlPath := c.Request.URL.Path

		c.Next()

		elapsed := time.Since(startedAt)
		status := c.Writer.Status()
		written := c.Writer.Size()
		pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
			remoteIP,
			verb,
			urlPath,
			c.Request.Proto,
			status,
			written,
			c.Request.UserAgent(),
			elapsed,
		)
	}
	engine.Use(accessLog)

	// see: issue: #81, #77 and #42 for CORS issues
	// preflight answers OPTIONS on any endpoint with permissive CORS headers
	// and lets every other method continue down the chain.
	preflight := func(c *gin.Context) {
		if c.Request.Method != "OPTIONS" {
			c.Next()
			return
		}

		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")

		// Echo back (sanitized) whatever headers the client asked for,
		// falling back to a fixed list when it asked for none.
		allowedHeaders := "Content-Type, Authorization, Accept, X-Requested-With"
		if requested := c.Request.Header.Get("Access-Control-Request-Headers"); requested != "" {
			allowedHeaders = SanitizeAccessControlRequestHeaderValues(requested)
		}
		c.Header("Access-Control-Allow-Headers", allowedHeaders)

		c.Header("Access-Control-Max-Age", "86400")
		c.AbortWithStatus(http.StatusNoContent)
	}
	engine.Use(preflight)

	// JSON-bodied endpoints that all share proxyOAIHandler.
	for _, route := range []string{
		"/v1/chat/completions",
		"/v1/completions", // legacy API, see issue #12
		"/v1/embeddings",
		"/v1/rerank",
		"/v1/audio/speech",
	} {
		engine.POST(route, pm.proxyOAIHandler)
	}

	// Multipart form endpoint.
	engine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)

	engine.GET("/v1/models", pm.listModelsHandler)

	// in proxymanager_loghandlers.go
	engine.GET("/logs", pm.sendLogsHandlers)
	engine.GET("/logs/stream", pm.streamLogsHandler)
	engine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
	engine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
	engine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)

	engine.GET("/upstream", pm.upstreamIndex)
	engine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)

	engine.GET("/unload", pm.unloadAllModelsHandler)
	engine.GET("/running", pm.listRunningProcessesHandler)

	// Index page served from the embedded HTML.
	engine.GET("/", func(c *gin.Context) {
		c.Header("Content-Type", "text/html")

		page, err := getHTMLFile("index.html")
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}

		if _, err := c.Writer.Write(page); err != nil {
			c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
		}
	})

	engine.GET("/favicon.ico", func(c *gin.Context) {
		icon, err := getHTMLFile("favicon.ico")
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}
		c.Data(http.StatusOK, "image/x-icon", icon)
	})

	// Disable console color for testing
	gin.DisableConsoleColor()
}
// ServeHTTP implements the http.Handler interface by delegating every request
// to the gin engine that setupGinEngine configured.
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pm.ginEngine.ServeHTTP(w, r)
}
// StopProcesses stops the upstream processes of every process group using the
// given strategy. It holds the ProxyManager lock for the whole call, stops the
// groups concurrently, and returns once all of them have finished.
// Unlike Shutdown it calls ProcessGroup.StopProcesses rather than
// ProcessGroup.Shutdown, so the groups remain available afterwards.
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
	pm.Lock()
	defer pm.Unlock()

	// One goroutine per group; wait for all of them before releasing the lock.
	var pending sync.WaitGroup
	pending.Add(len(pm.processGroups))
	for _, group := range pm.processGroups {
		group := group // per-iteration copy for the goroutine (required before Go 1.22)
		go func() {
			defer pending.Done()
			group.StopProcesses(strategy)
		}()
	}
	pending.Wait()
}
// Shutdown calls Shutdown on every process group managed by this ProxyManager.
// The groups are shut down concurrently while the ProxyManager lock is held,
// and the call returns only after every group has finished.
func (pm *ProxyManager) Shutdown() {
	pm.Lock()
	defer pm.Unlock()

	pm.proxyLogger.Debug("Shutdown() called in proxy manager")

	// Fan out the shutdown signal, then wait for every group.
	var pending sync.WaitGroup
	pending.Add(len(pm.processGroups))
	for _, group := range pm.processGroups {
		group := group // per-iteration copy for the goroutine (required before Go 1.22)
		go func() {
			defer pending.Done()
			group.Shutdown()
		}()
	}
	pending.Wait()
}
// swapProcessGroup resolves requestedModel to its real model name and returns
// the process group that contains it, together with that real name.
//
// When the target group is exclusive, every other non-persistent group is
// asked to stop its processes (waiting for in-flight requests) before the
// target is returned.
//
// An error is returned when the model name cannot be resolved or when no
// group has the resolved model as a member.
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
	// Resolve a possible alias to the real model name.
	realModelName, ok := pm.config.RealModelName(requestedModel)
	if !ok {
		return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
	}

	target := pm.findGroupByModelName(realModelName)
	if target == nil {
		return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
	}

	if !target.exclusive {
		return target, realModelName, nil
	}

	pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", target.id)
	for id, group := range pm.processGroups {
		// Leave the target itself and persistent groups running.
		if id == target.id || group.persistent {
			continue
		}
		group.StopProcesses(StopWaitForInflightRequest)
	}

	return target, realModelName, nil
}
// listModelsHandler serves GET /v1/models. It responds with a JSON object of
// the form {"object": "list", "data": [...]} containing one entry per
// configured model that is not marked Unlisted.
//
// Entries are emitted in ascending model ID order so that repeated calls
// return the same ordering; ranging over the config map directly would yield
// a different order on each call. If the request carries an Origin header it
// is echoed back in Access-Control-Allow-Origin.
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
	// Collect the listed model IDs and sort them for a stable response.
	ids := make([]string, 0, len(pm.config.Models))
	for id, modelConfig := range pm.config.Models {
		if modelConfig.Unlisted {
			continue
		}
		ids = append(ids, id)
	}
	sort.Strings(ids)

	// Non-nil slice so that an empty list encodes as [] rather than null.
	createdAt := time.Now().Unix()
	data := make([]interface{}, 0, len(ids))
	for _, id := range ids {
		data = append(data, map[string]interface{}{
			"id":       id,
			"object":   "model",
			"created":  createdAt,
			"owned_by": "llama-swap",
		})
	}

	// Set the Content-Type header to application/json
	c.Header("Content-Type", "application/json")

	if origin := c.Request.Header.Get("Origin"); origin != "" {
		c.Header("Access-Control-Allow-Origin", origin)
	}

	// Encode the data as JSON and write it to the response writer
	if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
		return
	}
}
// proxyToUpstream serves /upstream/:model_id/*upstreamPath. It resolves the
// model from the path, swaps to its process group, rewrites the request path
// to the upstreamPath portion, and proxies the request to the upstream.
//
// The resolved real model name, not the possibly-aliased name from the URL, is
// what gets passed to ProxyRequest, matching the other proxy handlers in this
// file. A ProxyRequest error is reported to the client and logged instead of
// being dropped.
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
	requestedModel := c.Param("model_id")

	if requestedModel == "" {
		pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
		return
	}

	processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		return
	}

	// rewrite the path so the upstream only sees the part after the model ID
	c.Request.URL.Path = c.Param("upstreamPath")

	if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
		pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s: %v", processGroup.id, realModelName, err)
		return
	}
}
// upstreamIndex serves GET /upstream: an HTML page listing every model that is
// not marked Unlisted, sorted by ID. Each entry links to /upstream/<id> and
// shows the current state of the model's process; the state is left blank when
// no group or process is found for the model.
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
	var page strings.Builder
	page.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><a href=\"/unload\">Unload all models</a><ul>")

	// Gather the listed model IDs and put them in a stable order.
	var listed []string
	for id, cfg := range pm.config.Models {
		if !cfg.Unlisted {
			listed = append(listed, id)
		}
	}
	sort.Strings(listed)

	for _, id := range listed {
		// Translate the process state into its display label.
		status := ""
		if group := pm.findGroupByModelName(id); group != nil {
			if proc := group.processes[id]; proc != nil {
				switch proc.CurrentState() {
				case StateReady:
					status = "Ready"
				case StateStarting:
					status = "Starting"
				case StateStopping:
					status = "Stopping"
				case StateShutdown:
					status = "Shutdown"
				case StateStopped:
					status = "Stopped"
				default:
					status = "Unknown"
				}
			}
		}
		fmt.Fprintf(&page, "<li><a href=\"/upstream/%s\">%s</a> - %s</li>", id, id, status)
	}

	page.WriteString("</ul></body></html>")
	c.Header("Content-Type", "text/html")
	c.String(http.StatusOK, page.String())
}
// proxyOAIHandler handles the JSON-bodied OpenAI-style endpoints. It reads the
// whole request body, extracts the "model" key, swaps to the process group
// that serves that model, optionally rewrites the model name sent upstream,
// and proxies the request.
//
// Responses:
//   - 400 when the body cannot be read or has no usable "model" key
//   - 500 when the process group cannot be resolved, the model name cannot be
//     rewritten, or proxying fails
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
	bodyBytes, err := io.ReadAll(c.Request.Body)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusBadRequest, "could not read request body")
		return
	}

	requestedModel := gjson.GetBytes(bodyBytes, "model").String()
	if requestedModel == "" {
		pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
		return
	}

	processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		return
	}

	// issue #69 allow custom model names to be sent to upstream
	useModelName := pm.config.Models[realModelName].UseModelName
	if useModelName != "" {
		bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
		if err != nil {
			pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
			return
		}
	}

	// The original body has been consumed; hand the (possibly rewritten)
	// bytes to the upstream as a fresh body.
	c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))

	// dechunk it as we already have all the body bytes see issue #11
	c.Request.Header.Del("transfer-encoding")
	c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
	c.Request.ContentLength = int64(len(bodyBytes))

	if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
		// Include the underlying error so the log line is actionable on its own.
		pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s: %v", processGroup.id, realModelName, err)
		return
	}
}
// proxyOAIPostFormHandler handles multipart/form-data endpoints. It parses the
// form, reads the "model" field to select the process group, rebuilds the
// multipart body (the original has been consumed by parsing), and proxies a
// new request carrying the rebuilt body to the upstream.
//
// In the rebuilt form every "model" value is replaced by the model's
// configured UseModelName when one is set, and by the requested model name
// otherwise. All other fields and all uploaded files are copied through.
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
	// 32MB max memory, larger files go to tmp disk
	const maxFormMemory = 32 << 20
	if err := c.Request.ParseMultipartForm(maxFormMemory); err != nil {
		pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
		return
	}

	// The model to route to comes from the form itself.
	requestedModel := c.Request.FormValue("model")
	if requestedModel == "" {
		pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
		return
	}

	processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		return
	}

	// issue #69 allow custom model names to be sent to upstream
	upstreamModel := requestedModel
	if custom := pm.config.Models[realModelName].UseModelName; custom != "" {
		upstreamModel = custom
	}

	// Parsing consumed the body, so the form is always rebuilt from scratch.
	var rebuilt bytes.Buffer
	form := multipart.NewWriter(&rebuilt)

	// Plain fields first; only the "model" field has its value substituted.
	for name, values := range c.Request.MultipartForm.Value {
		for _, original := range values {
			outgoing := original
			if name == "model" {
				outgoing = upstreamModel
			}

			part, err := form.CreateFormField(name)
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
				return
			}
			if _, err = part.Write([]byte(outgoing)); err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
				return
			}
		}
	}

	// Then every uploaded file, streamed into a new file part.
	for name, headers := range c.Request.MultipartForm.File {
		for _, header := range headers {
			part, err := form.CreateFormFile(name, header.Filename)
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
				return
			}

			upload, err := header.Open()
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
				return
			}

			// Close the upload right after copying, whether or not the copy succeeded.
			_, copyErr := io.Copy(part, upload)
			upload.Close()
			if copyErr != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
				return
			}
		}
	}

	// Closing the writer emits the terminating boundary.
	if err := form.Close(); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
		return
	}

	// Build the outgoing request around the rebuilt body.
	outbound, err := http.NewRequestWithContext(
		c.Request.Context(),
		c.Request.Method,
		c.Request.URL.String(),
		&rebuilt,
	)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
		return
	}

	// Carry over the original headers, then fix up the ones that describe the
	// body: the boundary and the length both changed when the form was rebuilt.
	bodyLen := rebuilt.Len()
	outbound.Header = c.Request.Header.Clone()
	outbound.Header.Set("Content-Type", form.FormDataContentType())
	outbound.Header.Set("Content-Length", strconv.Itoa(bodyLen))
	outbound.ContentLength = int64(bodyLen)

	if err := processGroup.ProxyRequest(realModelName, c.Writer, outbound); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
		pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
		return
	}
}
// sendErrorResponse writes an error reply with the given status code. When the
// request's Accept header contains "application/json" the message is sent as
// a JSON object {"error": message}; otherwise it is sent as a plain string.
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
	wantsJSON := strings.Contains(c.GetHeader("Accept"), "application/json")
	if !wantsJSON {
		c.String(statusCode, message)
		return
	}
	c.JSON(statusCode, gin.H{"error": message})
}
// unloadAllModelsHandler serves GET /unload. It stops the processes of every
// process group with the StopImmediately strategy and then replies "OK" with
// HTTP 200. StopProcesses waits for all groups to finish, so the reply is only
// sent afterwards.
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
pm.StopProcesses(StopImmediately)
c.String(http.StatusOK, "OK")
}
// listRunningProcessesHandler serves GET /running. It responds with HTTP 200
// and a JSON object {"running": [...]} holding one {"model", "state"} entry
// for every process, across all groups, whose current state is StateReady.
// The list is empty (not null) when nothing is ready.
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
	context.Header("Content-Type", "application/json")

	// Non-nil so that an empty result encodes as [] rather than null.
	runningProcesses := make([]gin.H, 0)

	for _, processGroup := range pm.processGroups {
		for _, process := range processGroup.processes {
			// Take a single snapshot through the accessor and reuse it for
			// both the check and the report. Reading the state field a second
			// time could report a value that differs from the one checked.
			state := process.CurrentState()
			if state != StateReady {
				continue
			}
			runningProcesses = append(runningProcesses, gin.H{
				"model": process.ID,
				"state": state,
			})
		}
	}

	// Put the results under the `running` key.
	response := gin.H{
		"running": runningProcesses,
	}

	context.JSON(http.StatusOK, response) // Always return 200 OK
}
// findGroupByModelName returns a process group that reports modelName as a
// member, or nil when no group does.
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
	for _, candidate := range pm.processGroups {
		if !candidate.HasMember(modelName) {
			continue
		}
		return candidate
	}
	return nil
}