package proxy

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

const (
	PROFILE_SPLIT_CHAR = ":"
)

type ProxyManager struct {
	sync.Mutex

	config    Config
	ginEngine *gin.Engine

	// logging
	proxyLogger    *LogMonitor
	upstreamLogger *LogMonitor
	muxLogger      *LogMonitor

	processGroups map[string]*ProcessGroup
}

// New creates a ProxyManager from the given configuration: it wires up the
// loggers, creates one ProcessGroup per configured group, and registers all
// HTTP routes on an internal gin engine.
func New(config Config) *ProxyManager {
	// set up loggers
	stdoutLogger := NewLogMonitorWriter(os.Stdout)
	upstreamLogger := NewLogMonitorWriter(stdoutLogger)
	proxyLogger := NewLogMonitorWriter(stdoutLogger)

	if config.LogRequests {
		proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
	}

	switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
	case "debug":
		proxyLogger.SetLogLevel(LevelDebug)
		upstreamLogger.SetLogLevel(LevelDebug)
	case "info":
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	case "warn":
		proxyLogger.SetLogLevel(LevelWarn)
		upstreamLogger.SetLogLevel(LevelWarn)
	case "error":
		proxyLogger.SetLogLevel(LevelError)
		upstreamLogger.SetLogLevel(LevelError)
	default:
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	}

	pm := &ProxyManager{
		config:         config,
		ginEngine:      gin.New(),
		proxyLogger:    proxyLogger,
		muxLogger:      stdoutLogger,
		upstreamLogger: upstreamLogger,
		processGroups:  make(map[string]*ProcessGroup),
	}

	// create the process groups
	for groupID := range config.Groups {
		processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
		pm.processGroups[groupID] = processGroup
	}

	pm.ginEngine.Use(func(c *gin.Context) {
		// Start timer
		start := time.Now()

		// capture these because /upstream/:model rewrites them in c.Next()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Process request
		c.Next()

		// Stop timer
		duration := time.Since(start)
		statusCode := c.Writer.Status()
		bodySize := c.Writer.Size()

		pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
			clientIP, method, path, c.Request.Proto,
			statusCode, bodySize, c.Request.UserAgent(), duration,
		)
	})

	// see: issue: #81, #77 and #42 for CORS issues
	// respond with permissive OPTIONS for any endpoint
	pm.ginEngine.Use(func(c *gin.Context) {
		if c.Request.Method == "OPTIONS" {
			c.Header("Access-Control-Allow-Origin", "*")
			c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")

			// allow whatever the client requested by default
			if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
				sanitized := SanitizeAccessControlRequestHeaderValues(headers)
				c.Header("Access-Control-Allow-Headers", sanitized)
			} else {
				c.Header(
					"Access-Control-Allow-Headers",
					"Content-Type, Authorization, Accept, X-Requested-With",
				)
			}

			c.Header("Access-Control-Max-Age", "86400")
			c.AbortWithStatus(http.StatusNoContent)
			return
		}

		c.Next()
	})

	// Set up routes using the Gin engine
	pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
	// Support legacy /v1/completions api, see issue #12
	pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)

	// Support embeddings
	pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)

	// Support audio/speech endpoint
	pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)

	pm.ginEngine.GET("/v1/models", pm.listModelsHandler)

	// in proxymanager_loghandlers.go
	pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
	pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
	pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)

	pm.ginEngine.GET("/upstream", pm.upstreamIndex)
	pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)

	pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
	pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)

	pm.ginEngine.GET("/", func(c *gin.Context) {
		// Set the Content-Type header to text/html
		c.Header("Content-Type", "text/html")

		// Write the embedded HTML content to the response
		htmlData, err := getHTMLFile("index.html")
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}
		_, err = c.Writer.Write(htmlData)
		if err != nil {
			c.String(http.StatusInternalServerError, fmt.Sprintf("failed to write response: %v", err))
			return
		}
	})

	pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
		if data, err := getHTMLFile("favicon.ico"); err == nil {
			c.Data(http.StatusOK, "image/x-icon", data)
		} else {
			c.String(http.StatusInternalServerError, err.Error())
		}
	})

	// Disable console color for testing
	gin.DisableConsoleColor()

	return pm
}
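// NOTE: usage sketch only, not part of the package API. A minimal example of
// wiring a ProxyManager into a program; it assumes a Config value has already
// been produced elsewhere (loadConfig below is a hypothetical helper, not
// defined in this package):
//
//	cfg, err := loadConfig("config.yaml") // hypothetical loader
//	if err != nil {
//		log.Fatal(err)
//	}
//	pm := proxy.New(cfg)
//	defer pm.Shutdown()
//	if err := pm.Run(":8080"); err != nil {
//		log.Fatal(err)
//	}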
// Run starts the proxy's HTTP server on the given address(es) using the
// underlying gin engine.
func (pm *ProxyManager) Run(addr ...string) error {
	return pm.ginEngine.Run(addr...)
}

// HandlerFunc exposes the proxy as a standard http.HandlerFunc so it can be
// mounted inside an existing net/http server.
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
	pm.ginEngine.ServeHTTP(w, r)
}

// StopProcesses stops all upstream processes across every process group.
func (pm *ProxyManager) StopProcesses() {
	pm.Lock()
	defer pm.Unlock()

	// stop Processes in parallel
	var wg sync.WaitGroup
	for _, processGroup := range pm.processGroups {
		wg.Add(1)
		go func(processGroup *ProcessGroup) {
			defer wg.Done()
			processGroup.stopProcesses()
		}(processGroup)
	}

	wg.Wait()
}
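// NOTE: embedding sketch only. Because HandlerFunc has the standard
// func(http.ResponseWriter, *http.Request) signature, the proxy can also be
// mounted in a caller-managed http.Server instead of calling Run; the address
// and timeout below are illustrative assumptions:
//
//	srv := &http.Server{
//		Addr:              ":8080",
//		Handler:           http.HandlerFunc(pm.HandlerFunc),
//		ReadHeaderTimeout: 5 * time.Second,
//	}
//	if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
//		log.Fatal(err)
//	}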
// Shutdown is called to shutdown all upstream processes
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() {
	pm.Lock()
	defer pm.Unlock()

	pm.proxyLogger.Debug("Shutdown() called in proxy manager")

	var wg sync.WaitGroup
	// Send shutdown signal to all processes in groups
	for _, processGroup := range pm.processGroups {
		wg.Add(1)
		go func(processGroup *ProcessGroup) {
			defer wg.Done()
			processGroup.Shutdown()
		}(processGroup)
	}

	wg.Wait()
}

func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) {
	// de-alias the requested model name and get the real one
	realModelName, found := pm.config.RealModelName(requestedModel)
	if !found {
		return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel)
	}

	processGroup := pm.findGroupByModelName(realModelName)
	if processGroup == nil {
		return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel)
	}

	if processGroup.exclusive {
		pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id)
		for groupId, otherGroup := range pm.processGroups {
			if groupId != processGroup.id && !otherGroup.persistent {
				otherGroup.StopProcesses()
			}
		}
	}

	return processGroup, realModelName, nil
}

func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
	data := []interface{}{}
	for id, modelConfig := range pm.config.Models {
		if modelConfig.Unlisted {
			continue
		}

		data = append(data, map[string]interface{}{
			"id":       id,
			"object":   "model",
			"created":  time.Now().Unix(),
			"owned_by": "llama-swap",
		})
	}

	// Set the Content-Type header to application/json
	c.Header("Content-Type", "application/json")

	if origin := c.Request.Header.Get("Origin"); origin != "" {
		c.Header("Access-Control-Allow-Origin", origin)
	}

	// Encode the data as JSON and write it to the response writer
	if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error()))
		return
	}
}

func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
	requestedModel := c.Param("model_id")

	if requestedModel == "" {
		pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
		return
	}

	processGroup, _, err := pm.swapProcessGroup(requestedModel)
	if err != nil {
		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
		// return early: processGroup is nil when swapProcessGroup fails
		return
	}

	// rewrite the path
	c.Request.URL.Path = c.Param("upstreamPath")
	processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
}

func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
	var html strings.Builder

	html.WriteString("\n