Refactor internal upstream process lifecycle management to recover better from unexpected situations. With this change, llama-swap should never need to be restarted because a child upstream process crashed. The `StateFailed` state was removed in favour of always attempting to start or restart a process.
567 lines
16 KiB
Go
567 lines
16 KiB
Go
package proxy
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)
|
|
|
|
const (
	// PROFILE_SPLIT_CHAR is the delimiter used when a requested model string
	// carries a profile prefix (presumably "profile:model" — the splitting
	// itself happens outside this file; verify against config handling).
	PROFILE_SPLIT_CHAR = ":"
)
|
|
|
|
// ProxyManager is the HTTP front end for llama-swap: it owns the gin engine,
// the shared log monitors, and one ProcessGroup per configured group. The
// embedded Mutex serializes StopProcesses/Shutdown across groups.
type ProxyManager struct {
	sync.Mutex

	config    Config      // immutable after New
	ginEngine *gin.Engine // routes and middleware set up in setupGinEngine

	// logging
	proxyLogger    *LogMonitor // llama-swap's own request/lifecycle logging
	upstreamLogger *LogMonitor // output captured from upstream child processes
	muxLogger      *LogMonitor // combined stdout stream both of the above feed into

	processGroups map[string]*ProcessGroup // keyed by group ID from config.Groups
}
|
|
|
|
func New(config Config) *ProxyManager {
|
|
// set up loggers
|
|
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
|
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
|
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
|
|
|
if config.LogRequests {
|
|
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
|
}
|
|
|
|
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
|
case "debug":
|
|
proxyLogger.SetLogLevel(LevelDebug)
|
|
upstreamLogger.SetLogLevel(LevelDebug)
|
|
case "info":
|
|
proxyLogger.SetLogLevel(LevelInfo)
|
|
upstreamLogger.SetLogLevel(LevelInfo)
|
|
case "warn":
|
|
proxyLogger.SetLogLevel(LevelWarn)
|
|
upstreamLogger.SetLogLevel(LevelWarn)
|
|
case "error":
|
|
proxyLogger.SetLogLevel(LevelError)
|
|
upstreamLogger.SetLogLevel(LevelError)
|
|
default:
|
|
proxyLogger.SetLogLevel(LevelInfo)
|
|
upstreamLogger.SetLogLevel(LevelInfo)
|
|
}
|
|
|
|
pm := &ProxyManager{
|
|
config: config,
|
|
ginEngine: gin.New(),
|
|
|
|
proxyLogger: proxyLogger,
|
|
muxLogger: stdoutLogger,
|
|
upstreamLogger: upstreamLogger,
|
|
|
|
processGroups: make(map[string]*ProcessGroup),
|
|
}
|
|
|
|
// create the process groups
|
|
for groupID := range config.Groups {
|
|
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
|
pm.processGroups[groupID] = processGroup
|
|
}
|
|
|
|
pm.setupGinEngine()
|
|
return pm
|
|
}
|
|
|
|
// setupGinEngine registers middleware and every HTTP route on pm.ginEngine.
// Middleware order matters: the request logger runs first so it captures the
// pre-rewrite method/path, then the permissive CORS preflight handler.
func (pm *ProxyManager) setupGinEngine() {
	// request-logging middleware: logs one line per request after completion
	pm.ginEngine.Use(func(c *gin.Context) {
		// Start timer
		start := time.Now()

		// capture these because /upstream/:model rewrites them in c.Next()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Process request
		c.Next()

		// Stop timer
		duration := time.Since(start)

		statusCode := c.Writer.Status()
		bodySize := c.Writer.Size()

		pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
			clientIP,
			method,
			path,
			c.Request.Proto,
			statusCode,
			bodySize,
			c.Request.UserAgent(),
			duration,
		)
	})

	// see: issue: #81, #77 and #42 for CORS issues
	// respond with permissive OPTIONS for any endpoint
	pm.ginEngine.Use(func(c *gin.Context) {
		if c.Request.Method == "OPTIONS" {
			c.Header("Access-Control-Allow-Origin", "*")
			c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")

			// allow whatever the client requested by default
			if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
				sanitized := SanitizeAccessControlRequestHeaderValues(headers)
				c.Header("Access-Control-Allow-Headers", sanitized)
			} else {
				c.Header(
					"Access-Control-Allow-Headers",
					"Content-Type, Authorization, Accept, X-Requested-With",
				)
			}
			// let browsers cache the preflight result for 24 hours
			c.Header("Access-Control-Max-Age", "86400")
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})

	// Set up routes using the Gin engine
	pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
	// Support legacy /v1/completions api, see issue #12
	pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)

	// Support embeddings
	pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)

	// Support audio/speech endpoint
	pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
	// multipart endpoint has a dedicated handler that rebuilds the form body
	pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)

	pm.ginEngine.GET("/v1/models", pm.listModelsHandler)

	// in proxymanager_loghandlers.go
	pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
	pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
	pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)

	pm.ginEngine.GET("/upstream", pm.upstreamIndex)
	pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)

	pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)

	pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)

	pm.ginEngine.GET("/", func(c *gin.Context) {
		// Set the Content-Type header to text/html
		c.Header("Content-Type", "text/html")

		// Write the embedded HTML content to the response
		htmlData, err := getHTMLFile("index.html")
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}
		_, err = c.Writer.Write(htmlData)
		if err != nil {
			c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
			return
		}
	})

	pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
		if data, err := getHTMLFile("favicon.ico"); err == nil {
			c.Data(http.StatusOK, "image/x-icon", data)
		} else {
			c.String(http.StatusInternalServerError, err.Error())
		}
	})

	// Disable console color for testing
	gin.DisableConsoleColor()
}
|
|
|
|
// ServeHTTP implements the http.Handler interface by delegating every
// request to the internal gin engine, so a ProxyManager can be mounted
// directly on an http.Server.
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	pm.ginEngine.ServeHTTP(w, r)
}
|
|
|
|
// StopProcesses acquires a lock and stops all running upstream processes.
|
|
// This is the public method safe for concurrent calls.
|
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
|
// a complete shutdown, allowing for process replacement without full termination.
|
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
// stop Processes in parallel
|
|
var wg sync.WaitGroup
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.StopProcesses(strategy)
|
|
}(processGroup)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// Shutdown stops all processes managed by this ProxyManager
|
|
func (pm *ProxyManager) Shutdown() {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
|
|
|
var wg sync.WaitGroup
|
|
// Send shutdown signal to all process in groups
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.Shutdown()
|
|
}(processGroup)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
|
|
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
|
// de-alias the real model name and get a real one
|
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
|
if !found {
|
|
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
|
}
|
|
|
|
processGroup := pm.findGroupByModelName(realModelName)
|
|
if processGroup == nil {
|
|
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
|
}
|
|
|
|
if processGroup.exclusive {
|
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
|
for groupId, otherGroup := range pm.processGroups {
|
|
if groupId != processGroup.id && !otherGroup.persistent {
|
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
|
}
|
|
}
|
|
}
|
|
|
|
return processGroup, realModelName, nil
|
|
}
|
|
|
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|
data := []interface{}{}
|
|
for id, modelConfig := range pm.config.Models {
|
|
if modelConfig.Unlisted {
|
|
continue
|
|
}
|
|
|
|
data = append(data, map[string]interface{}{
|
|
"id": id,
|
|
"object": "model",
|
|
"created": time.Now().Unix(),
|
|
"owned_by": "llama-swap",
|
|
})
|
|
}
|
|
|
|
// Set the Content-Type header to application/json
|
|
c.Header("Content-Type", "application/json")
|
|
|
|
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
|
c.Header("Access-Control-Allow-Origin", origin)
|
|
}
|
|
|
|
// Encode the data as JSON and write it to the response writer
|
|
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"object": "list", "data": data}); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
|
|
return
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|
requestedModel := c.Param("model_id")
|
|
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
|
return
|
|
}
|
|
|
|
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// rewrite the path
|
|
c.Request.URL.Path = c.Param("upstreamPath")
|
|
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
|
}
|
|
|
|
func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
|
|
var html strings.Builder
|
|
|
|
html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><a href=\"/unload\">Unload all models</a><ul>")
|
|
|
|
// Extract keys and sort them
|
|
var modelIDs []string
|
|
for modelID, modelConfig := range pm.config.Models {
|
|
if modelConfig.Unlisted {
|
|
continue
|
|
}
|
|
|
|
modelIDs = append(modelIDs, modelID)
|
|
}
|
|
sort.Strings(modelIDs)
|
|
|
|
// Iterate over sorted keys
|
|
for _, modelID := range modelIDs {
|
|
// Get process state
|
|
processGroup := pm.findGroupByModelName(modelID)
|
|
var state string
|
|
if processGroup != nil {
|
|
process := processGroup.processes[modelID]
|
|
if process != nil {
|
|
var stateStr string
|
|
switch process.CurrentState() {
|
|
case StateReady:
|
|
stateStr = "Ready"
|
|
case StateStarting:
|
|
stateStr = "Starting"
|
|
case StateStopping:
|
|
stateStr = "Stopping"
|
|
case StateShutdown:
|
|
stateStr = "Shutdown"
|
|
case StateStopped:
|
|
stateStr = "Stopped"
|
|
default:
|
|
stateStr = "Unknown"
|
|
}
|
|
state = stateStr
|
|
}
|
|
}
|
|
html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a> - %s</li>", modelID, modelID, state))
|
|
}
|
|
html.WriteString("</ul></body></html>")
|
|
c.Header("Content-Type", "text/html")
|
|
c.String(http.StatusOK, html.String())
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
|
return
|
|
}
|
|
|
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
|
return
|
|
}
|
|
|
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// issue #69 allow custom model names to be sent to upstream
|
|
useModelName := pm.config.Models[realModelName].UseModelName
|
|
if useModelName != "" {
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
|
return
|
|
}
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
|
// dechunk it as we already have all the body bytes see issue #11
|
|
c.Request.Header.Del("transfer-encoding")
|
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
|
c.Request.ContentLength = int64(len(bodyBytes))
|
|
|
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
|
return
|
|
}
|
|
}
|
|
|
|
// proxyOAIPostFormHandler handles multipart/form-data endpoints (e.g.
// /v1/audio/transcriptions). Parsing the form consumes the request body, so
// the form is rebuilt field-by-field (with the "model" value possibly
// rewritten) into a fresh multipart body before proxying upstream.
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
	// Parse multipart form
	if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
		pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
		return
	}

	// Get model parameter from the form
	requestedModel := c.Request.FormValue("model")
	if requestedModel == "" {
		pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
		return
	}

	// swap in the process group serving this model (may stop other groups)
	processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		return
	}

	// We need to reconstruct the multipart form in any case since the body is consumed
	// Create a new buffer for the reconstructed request
	var requestBuffer bytes.Buffer
	multipartWriter := multipart.NewWriter(&requestBuffer)

	// Copy all form values
	for key, values := range c.Request.MultipartForm.Value {
		for _, value := range values {
			fieldValue := value
			// If this is the model field and we have a profile, use just the model name
			if key == "model" {
				// # issue #69 allow custom model names to be sent to upstream
				useModelName := pm.config.Models[realModelName].UseModelName

				if useModelName != "" {
					fieldValue = useModelName
				} else {
					fieldValue = requestedModel
				}
			}
			field, err := multipartWriter.CreateFormField(key)
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
				return
			}
			if _, err = field.Write([]byte(fieldValue)); err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
				return
			}
		}
	}

	// Copy all files from the original request
	for key, fileHeaders := range c.Request.MultipartForm.File {
		for _, fileHeader := range fileHeaders {
			formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
				return
			}

			file, err := fileHeader.Open()
			if err != nil {
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
				return
			}

			// closed explicitly (not deferred) so each file is released
			// promptly inside the loop rather than at function return
			if _, err = io.Copy(formFile, file); err != nil {
				file.Close()
				pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
				return
			}
			file.Close()
		}
	}

	// Close the multipart writer to finalize the form
	if err := multipartWriter.Close(); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
		return
	}

	// Create a new request with the reconstructed form data
	modifiedReq, err := http.NewRequestWithContext(
		c.Request.Context(),
		c.Request.Method,
		c.Request.URL.String(),
		&requestBuffer,
	)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
		return
	}

	// Copy the headers from the original request
	modifiedReq.Header = c.Request.Header.Clone()
	// Content-Type must carry the NEW writer's boundary, not the original's
	modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())

	// set the content length of the body
	modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
	modifiedReq.ContentLength = int64(requestBuffer.Len())

	// Use the modified request for proxying
	if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
		pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
		return
	}
}
|
|
|
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
|
acceptHeader := c.GetHeader("Accept")
|
|
|
|
if strings.Contains(acceptHeader, "application/json") {
|
|
c.JSON(statusCode, gin.H{"error": message})
|
|
} else {
|
|
c.String(statusCode, message)
|
|
}
|
|
}
|
|
|
|
// unloadAllModelsHandler handles GET /unload by immediately stopping every
// running upstream process, then responding with a plain "OK".
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
	pm.StopProcesses(StopImmediately)
	c.String(http.StatusOK, "OK")
}
|
|
|
|
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|
context.Header("Content-Type", "application/json")
|
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
|
|
|
for _, processGroup := range pm.processGroups {
|
|
for _, process := range processGroup.processes {
|
|
if process.CurrentState() == StateReady {
|
|
runningProcesses = append(runningProcesses, gin.H{
|
|
"model": process.ID,
|
|
"state": process.state,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Put the results under the `running` key.
|
|
response := gin.H{
|
|
"running": runningProcesses,
|
|
}
|
|
|
|
context.JSON(http.StatusOK, response) // Always return 200 OK
|
|
}
|
|
|
|
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
|
for _, group := range pm.processGroups {
|
|
if group.HasMember(modelName) {
|
|
return group
|
|
}
|
|
}
|
|
return nil
|
|
}
|