575 lines
16 KiB
Go
575 lines
16 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/sjson"
|
|
)
|
|
|
|
const (
|
|
PROFILE_SPLIT_CHAR = ":"
|
|
)
|
|
|
|
type ProxyManager struct {
|
|
sync.Mutex
|
|
|
|
config Config
|
|
ginEngine *gin.Engine
|
|
|
|
// logging
|
|
proxyLogger *LogMonitor
|
|
upstreamLogger *LogMonitor
|
|
muxLogger *LogMonitor
|
|
|
|
metricsMonitor *MetricsMonitor
|
|
|
|
processGroups map[string]*ProcessGroup
|
|
|
|
// shutdown signaling
|
|
shutdownCtx context.Context
|
|
shutdownCancel context.CancelFunc
|
|
}
|
|
|
|
func New(config Config) *ProxyManager {
|
|
// set up loggers
|
|
stdoutLogger := NewLogMonitorWriter(os.Stdout)
|
|
upstreamLogger := NewLogMonitorWriter(stdoutLogger)
|
|
proxyLogger := NewLogMonitorWriter(stdoutLogger)
|
|
|
|
if config.LogRequests {
|
|
proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
|
|
}
|
|
|
|
switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
|
|
case "debug":
|
|
proxyLogger.SetLogLevel(LevelDebug)
|
|
upstreamLogger.SetLogLevel(LevelDebug)
|
|
case "info":
|
|
proxyLogger.SetLogLevel(LevelInfo)
|
|
upstreamLogger.SetLogLevel(LevelInfo)
|
|
case "warn":
|
|
proxyLogger.SetLogLevel(LevelWarn)
|
|
upstreamLogger.SetLogLevel(LevelWarn)
|
|
case "error":
|
|
proxyLogger.SetLogLevel(LevelError)
|
|
upstreamLogger.SetLogLevel(LevelError)
|
|
default:
|
|
proxyLogger.SetLogLevel(LevelInfo)
|
|
upstreamLogger.SetLogLevel(LevelInfo)
|
|
}
|
|
|
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
|
|
|
pm := &ProxyManager{
|
|
config: config,
|
|
ginEngine: gin.New(),
|
|
|
|
proxyLogger: proxyLogger,
|
|
muxLogger: stdoutLogger,
|
|
upstreamLogger: upstreamLogger,
|
|
|
|
metricsMonitor: NewMetricsMonitor(&config),
|
|
|
|
processGroups: make(map[string]*ProcessGroup),
|
|
|
|
shutdownCtx: shutdownCtx,
|
|
shutdownCancel: shutdownCancel,
|
|
}
|
|
|
|
// create the process groups
|
|
for groupID := range config.Groups {
|
|
processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
|
|
pm.processGroups[groupID] = processGroup
|
|
}
|
|
|
|
pm.setupGinEngine()
|
|
return pm
|
|
}
|
|
|
|
func (pm *ProxyManager) setupGinEngine() {
|
|
pm.ginEngine.Use(func(c *gin.Context) {
|
|
// Start timer
|
|
start := time.Now()
|
|
|
|
// capture these because /upstream/:model rewrites them in c.Next()
|
|
clientIP := c.ClientIP()
|
|
method := c.Request.Method
|
|
path := c.Request.URL.Path
|
|
|
|
// Process request
|
|
c.Next()
|
|
|
|
// Stop timer
|
|
duration := time.Since(start)
|
|
|
|
statusCode := c.Writer.Status()
|
|
bodySize := c.Writer.Size()
|
|
|
|
pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
|
|
clientIP,
|
|
method,
|
|
path,
|
|
c.Request.Proto,
|
|
statusCode,
|
|
bodySize,
|
|
c.Request.UserAgent(),
|
|
duration,
|
|
)
|
|
})
|
|
|
|
// see: issue: #81, #77 and #42 for CORS issues
|
|
// respond with permissive OPTIONS for any endpoint
|
|
pm.ginEngine.Use(func(c *gin.Context) {
|
|
if c.Request.Method == "OPTIONS" {
|
|
c.Header("Access-Control-Allow-Origin", "*")
|
|
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
|
|
|
// allow whatever the client requested by default
|
|
if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
|
sanitized := SanitizeAccessControlRequestHeaderValues(headers)
|
|
c.Header("Access-Control-Allow-Headers", sanitized)
|
|
} else {
|
|
c.Header(
|
|
"Access-Control-Allow-Headers",
|
|
"Content-Type, Authorization, Accept, X-Requested-With",
|
|
)
|
|
}
|
|
c.Header("Access-Control-Max-Age", "86400")
|
|
c.AbortWithStatus(http.StatusNoContent)
|
|
return
|
|
}
|
|
c.Next()
|
|
})
|
|
|
|
mm := MetricsMiddleware(pm)
|
|
|
|
// Set up routes using the Gin engine
|
|
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
|
// Support legacy /v1/completions api, see issue #12
|
|
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
|
|
|
// Support embeddings
|
|
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
|
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
|
|
|
// Support audio/speech endpoint
|
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
|
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
|
|
|
|
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
|
|
|
|
// in proxymanager_loghandlers.go
|
|
pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
|
|
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
|
|
pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
|
|
|
|
/**
|
|
* User Interface Endpoints
|
|
*/
|
|
pm.ginEngine.GET("/", func(c *gin.Context) {
|
|
c.Redirect(http.StatusFound, "/ui")
|
|
})
|
|
|
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
|
c.Redirect(http.StatusFound, "/ui/models")
|
|
})
|
|
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
|
|
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
|
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
|
|
|
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
|
|
if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
|
|
c.Data(http.StatusOK, "image/x-icon", data)
|
|
} else {
|
|
c.String(http.StatusInternalServerError, err.Error())
|
|
}
|
|
})
|
|
|
|
reactFS, err := GetReactFS()
|
|
if err != nil {
|
|
pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
|
|
} else {
|
|
|
|
// serve files that exist under /ui/*
|
|
pm.ginEngine.StaticFS("/ui", reactFS)
|
|
|
|
// server SPA for UI under /ui/*
|
|
pm.ginEngine.NoRoute(func(c *gin.Context) {
|
|
if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
|
|
c.AbortWithStatus(http.StatusNotFound)
|
|
return
|
|
}
|
|
|
|
file, err := reactFS.Open("index.html")
|
|
if err != nil {
|
|
c.String(http.StatusInternalServerError, err.Error())
|
|
return
|
|
}
|
|
defer file.Close()
|
|
http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
|
|
|
|
})
|
|
}
|
|
|
|
// see: proxymanager_api.go
|
|
// add API handler functions
|
|
addApiHandlers(pm)
|
|
|
|
// Disable console color for testing
|
|
gin.DisableConsoleColor()
|
|
}
|
|
|
|
// ServeHTTP implements http.Handler interface
|
|
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
pm.ginEngine.ServeHTTP(w, r)
|
|
}
|
|
|
|
// StopProcesses acquires a lock and stops all running upstream processes.
|
|
// This is the public method safe for concurrent calls.
|
|
// Unlike Shutdown, this method only stops the processes but doesn't perform
|
|
// a complete shutdown, allowing for process replacement without full termination.
|
|
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
// stop Processes in parallel
|
|
var wg sync.WaitGroup
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.StopProcesses(strategy)
|
|
}(processGroup)
|
|
}
|
|
|
|
wg.Wait()
|
|
}
|
|
|
|
// Shutdown stops all processes managed by this ProxyManager
|
|
func (pm *ProxyManager) Shutdown() {
|
|
pm.Lock()
|
|
defer pm.Unlock()
|
|
|
|
pm.proxyLogger.Debug("Shutdown() called in proxy manager")
|
|
|
|
var wg sync.WaitGroup
|
|
// Send shutdown signal to all process in groups
|
|
for _, processGroup := range pm.processGroups {
|
|
wg.Add(1)
|
|
go func(processGroup *ProcessGroup) {
|
|
defer wg.Done()
|
|
processGroup.Shutdown()
|
|
}(processGroup)
|
|
}
|
|
wg.Wait()
|
|
pm.shutdownCancel()
|
|
}
|
|
|
|
func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
|
|
// de-alias the real model name and get a real one
|
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
|
if !found {
|
|
return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
|
|
}
|
|
|
|
processGroup := pm.findGroupByModelName(realModelName)
|
|
if processGroup == nil {
|
|
return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
|
|
}
|
|
|
|
if processGroup.exclusive {
|
|
pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
|
|
for groupId, otherGroup := range pm.processGroups {
|
|
if groupId != processGroup.id && !otherGroup.persistent {
|
|
otherGroup.StopProcesses(StopWaitForInflightRequest)
|
|
}
|
|
}
|
|
}
|
|
|
|
return processGroup, realModelName, nil
|
|
}
|
|
|
|
func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|
data := make([]gin.H, 0, len(pm.config.Models))
|
|
createdTime := time.Now().Unix()
|
|
|
|
for id, modelConfig := range pm.config.Models {
|
|
if modelConfig.Unlisted {
|
|
continue
|
|
}
|
|
|
|
record := gin.H{
|
|
"id": id,
|
|
"object": "model",
|
|
"created": createdTime,
|
|
"owned_by": "llama-swap",
|
|
}
|
|
|
|
if name := strings.TrimSpace(modelConfig.Name); name != "" {
|
|
record["name"] = name
|
|
}
|
|
if desc := strings.TrimSpace(modelConfig.Description); desc != "" {
|
|
record["description"] = desc
|
|
}
|
|
|
|
data = append(data, record)
|
|
}
|
|
|
|
// Set CORS headers if origin exists
|
|
if origin := c.GetHeader("Origin"); origin != "" {
|
|
c.Header("Access-Control-Allow-Origin", origin)
|
|
}
|
|
|
|
// Use gin's JSON method which handles content-type and encoding
|
|
c.JSON(http.StatusOK, gin.H{
|
|
"object": "list",
|
|
"data": data,
|
|
})
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|
requestedModel := c.Param("model_id")
|
|
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
|
return
|
|
}
|
|
|
|
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// rewrite the path
|
|
c.Request.URL.Path = c.Param("upstreamPath")
|
|
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
|
return
|
|
}
|
|
|
|
realModelName := c.GetString("ls-real-model-name") // Should be set in MetricsMiddleware
|
|
if realModelName == "" {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "ls-real-model-name not set")
|
|
return
|
|
}
|
|
|
|
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// issue #69 allow custom model names to be sent to upstream
|
|
useModelName := pm.config.Models[realModelName].UseModelName
|
|
if useModelName != "" {
|
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error()))
|
|
return
|
|
}
|
|
}
|
|
|
|
// issue #174 strip parameters from the JSON body
|
|
stripParams, err := pm.config.Models[realModelName].Filters.SanitizedStripParams()
|
|
if err != nil { // just log it and continue
|
|
pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error())
|
|
} else {
|
|
for _, param := range stripParams {
|
|
pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param)
|
|
bodyBytes, err = sjson.DeleteBytes(bodyBytes, param)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param))
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
|
// dechunk it as we already have all the body bytes see issue #11
|
|
c.Request.Header.Del("transfer-encoding")
|
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
|
c.Request.ContentLength = int64(len(bodyBytes))
|
|
|
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|
// Parse multipart form
|
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// Get model parameter from the form
|
|
requestedModel := c.Request.FormValue("model")
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data")
|
|
return
|
|
}
|
|
|
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
|
return
|
|
}
|
|
|
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
// Create a new buffer for the reconstructed request
|
|
var requestBuffer bytes.Buffer
|
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
|
|
// Copy all form values
|
|
for key, values := range c.Request.MultipartForm.Value {
|
|
for _, value := range values {
|
|
fieldValue := value
|
|
// If this is the model field and we have a profile, use just the model name
|
|
if key == "model" {
|
|
// # issue #69 allow custom model names to be sent to upstream
|
|
useModelName := pm.config.Models[realModelName].UseModelName
|
|
|
|
if useModelName != "" {
|
|
fieldValue = useModelName
|
|
} else {
|
|
fieldValue = requestedModel
|
|
}
|
|
}
|
|
field, err := multipartWriter.CreateFormField(key)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")
|
|
return
|
|
}
|
|
if _, err = field.Write([]byte(fieldValue)); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field")
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Copy all files from the original request
|
|
for key, fileHeaders := range c.Request.MultipartForm.File {
|
|
for _, fileHeader := range fileHeaders {
|
|
formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file")
|
|
return
|
|
}
|
|
|
|
file, err := fileHeader.Open()
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file")
|
|
return
|
|
}
|
|
|
|
if _, err = io.Copy(formFile, file); err != nil {
|
|
file.Close()
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data")
|
|
return
|
|
}
|
|
file.Close()
|
|
}
|
|
}
|
|
|
|
// Close the multipart writer to finalize the form
|
|
if err := multipartWriter.Close(); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form")
|
|
return
|
|
}
|
|
|
|
// Create a new request with the reconstructed form data
|
|
modifiedReq, err := http.NewRequestWithContext(
|
|
c.Request.Context(),
|
|
c.Request.Method,
|
|
c.Request.URL.String(),
|
|
&requestBuffer,
|
|
)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, "error creating modified request")
|
|
return
|
|
}
|
|
|
|
// Copy the headers from the original request
|
|
modifiedReq.Header = c.Request.Header.Clone()
|
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
|
|
|
// set the content length of the body
|
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
|
|
|
// Use the modified request for proxying
|
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) {
|
|
acceptHeader := c.GetHeader("Accept")
|
|
|
|
if strings.Contains(acceptHeader, "application/json") {
|
|
c.JSON(statusCode, gin.H{"error": message})
|
|
} else {
|
|
c.String(statusCode, message)
|
|
}
|
|
}
|
|
|
|
func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
|
pm.StopProcesses(StopImmediately)
|
|
c.String(http.StatusOK, "OK")
|
|
}
|
|
|
|
func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) {
|
|
context.Header("Content-Type", "application/json")
|
|
runningProcesses := make([]gin.H, 0) // Default to an empty response.
|
|
|
|
for _, processGroup := range pm.processGroups {
|
|
for _, process := range processGroup.processes {
|
|
if process.CurrentState() == StateReady {
|
|
runningProcesses = append(runningProcesses, gin.H{
|
|
"model": process.ID,
|
|
"state": process.state,
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
// Put the results under the `running` key.
|
|
response := gin.H{
|
|
"running": runningProcesses,
|
|
}
|
|
|
|
context.JSON(http.StatusOK, response) // Always return 200 OK
|
|
}
|
|
|
|
func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup {
|
|
for _, group := range pm.processGroups {
|
|
if group.HasMember(modelName) {
|
|
return group
|
|
}
|
|
}
|
|
return nil
|
|
}
|