package proxy

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

const (
	// PROFILE_SPLIT_CHAR is the separator looked for in a requested model
	// name to tell a profile prefix apart from the model name.
	PROFILE_SPLIT_CHAR = ":"
)

// ProxyManager owns the gin engine that serves the proxy endpoints and the
// set of upstream processes that are currently tracked.
//
// The embedded mutex is held by the methods that read or replace
// currentProcesses.
type ProxyManager struct {
	sync.Mutex

	config           *Config
	currentProcesses map[string]*Process
	ginEngine        *gin.Engine

	// logging
	proxyLogger    *LogMonitor
	upstreamLogger *LogMonitor
	muxLogger      *LogMonitor
}

// New builds a ProxyManager for config. It creates the loggers, applies
// config.LogLevel to the proxy logger (falling back to info for unknown
// values), and registers the middleware and routes on a fresh gin engine.
func New(config *Config) *ProxyManager {
	// set up loggers
	stdoutLogger := NewLogMonitorWriter(os.Stdout)
	upstreamLogger := NewLogMonitorWriter(stdoutLogger)
	proxyLogger := NewLogMonitorWriter(stdoutLogger)

	if config.LogRequests {
		proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
	}

	switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
	case "debug":
		proxyLogger.SetLogLevel(LevelDebug)
	case "info":
		proxyLogger.SetLogLevel(LevelInfo)
	case "warn":
		proxyLogger.SetLogLevel(LevelWarn)
	case "error":
		proxyLogger.SetLogLevel(LevelError)
	default:
		proxyLogger.SetLogLevel(LevelInfo)
	}

	pm := &ProxyManager{
		config:           config,
		currentProcesses: make(map[string]*Process),
		ginEngine:        gin.New(),

		proxyLogger:    proxyLogger,
		muxLogger:      stdoutLogger,
		upstreamLogger: upstreamLogger,
	}

	// Request logging middleware: one info line per request.
	pm.ginEngine.Use(func(c *gin.Context) {
		// Start timer
		start := time.Now()

		// capture these because /upstream/:model_id rewrites them in c.Next()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Process request
		c.Next()

		// Stop timer
		duration := time.Since(start)

		statusCode := c.Writer.Status()
		bodySize := c.Writer.Size()

		pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
			clientIP,
			method,
			path,
			c.Request.Proto,
			statusCode,
			bodySize,
			c.Request.UserAgent(),
			duration,
		)
	})

	// see: issue: #81, #77 and #42 for CORS issues
	// respond with permissive OPTIONS for any endpoint
	pm.ginEngine.Use(func(c *gin.Context) {
		if c.Request.Method == "OPTIONS" {
			c.Header("Access-Control-Allow-Origin", "*")
			c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")

			// allow whatever the client requested by default
			if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
				sanitized := SanitizeAccessControlRequestHeaderValues(headers)
				c.Header("Access-Control-Allow-Headers", sanitized)
			} else {
				c.Header(
					"Access-Control-Allow-Headers",
					"Content-Type, Authorization, Accept, X-Requested-With",
				)
			}
			c.Header("Access-Control-Max-Age", "86400")
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})

	// Set up routes using the Gin engine
	pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
	// Support legacy /v1/completions api, see issue #12
	pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)

	// Support embeddings
	pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)

	// Support audio/speech endpoint
	pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)

	pm.ginEngine.GET("/v1/models", pm.listModelsHandler)

	// in proxymanager_loghandlers.go
	pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
	pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
	pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE)

	pm.ginEngine.GET("/upstream", pm.upstreamIndex)
	pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)

	pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)

	pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)

	pm.ginEngine.GET("/", func(c *gin.Context) {
		// Load the embedded page first so the error path below is not
		// affected by headers meant for the success path.
		htmlData, err := getHTMLFile("index.html")
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}

		// Set the Content-Type only once the page is known to be available:
		// c.String does not replace a Content-Type that is already set, so
		// setting it earlier would label the plain-text error as text/html.
		c.Header("Content-Type", "text/html")

		// Write the embedded HTML content to the response
		if _, err := c.Writer.Write(htmlData); err != nil {
			// The status line and headers are already sent at this point, so
			// the failure can only be logged, not reported to the client.
			pm.proxyLogger.Warn(fmt.Sprintf("failed to write response: %v", err))
		}
	})

	pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
		if data, err := getHTMLFile("favicon.ico"); err == nil {
			c.Data(http.StatusOK, "image/x-icon", data)
		} else {
			c.String(http.StatusInternalServerError, err.Error())
		}
	})

	// Disable console color for testing
	gin.DisableConsoleColor()

	return pm
}

// Run starts the gin engine; addr is passed through to gin's Engine.Run and
// its error is returned unchanged.
func (pm *ProxyManager) Run(addr ...string) error {
	return pm.ginEngine.Run(addr...)
}

// HandlerFunc serves a single request through the gin engine, so the
// ProxyManager can be mounted wherever an http.HandlerFunc is expected.
func (pm *ProxyManager) HandlerFunc(w http.ResponseWriter, r *http.Request) {
	pm.ginEngine.ServeHTTP(w, r)
}

// StopProcesses stops every tracked process while holding the manager lock.
func (pm *ProxyManager) StopProcesses() {
	pm.Lock()
	defer pm.Unlock()

	pm.stopProcesses()
}

// stopProcesses is the lock-free core of StopProcesses, for internal usage.
// It does not take the manager lock itself; the callers in this file already
// hold it.
func (pm *ProxyManager) stopProcesses() {
	if len(pm.currentProcesses) == 0 {
		return
	}

	// stop Processes in parallel
	var wg sync.WaitGroup
	for _, process := range pm.currentProcesses {
		wg.Add(1)
		go func(process *Process) {
			defer wg.Done()
			process.Stop()
		}(process)
	}
	wg.Wait()

	// Start over with an empty set once everything has stopped.
	pm.currentProcesses = make(map[string]*Process)
}

// Shutdown is called to shutdown all upstream processes
// when llama-swap is shutting down.
func (pm *ProxyManager) Shutdown() { pm.Lock() defer pm.Unlock() // shutdown process in parallel var wg sync.WaitGroup for _, process := range pm.currentProcesses { wg.Add(1) go func(process *Process) { defer wg.Done() process.Shutdown() }(process) } wg.Wait() } func (pm *ProxyManager) listModelsHandler(c *gin.Context) { data := []interface{}{} for id, modelConfig := range pm.config.Models { if modelConfig.Unlisted { continue } data = append(data, map[string]interface{}{ "id": id, "object": "model", "created": time.Now().Unix(), "owned_by": "llama-swap", }) } // Set the Content-Type header to application/json c.Header("Content-Type", "application/json") if origin := c.Request.Header.Get("Origin"); origin != "" { c.Header("Access-Control-Allow-Origin", origin) } // Encode the data as JSON and write it to the response writer if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error encoding JSON %s", err.Error())) return } } func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) { pm.Lock() defer pm.Unlock() // Check if requestedModel contains a PROFILE_SPLIT_CHAR profileName, modelName := splitRequestedModel(requestedModel) if profileName != "" { if _, found := pm.config.Profiles[profileName]; !found { return nil, fmt.Errorf("model group not found %s", profileName) } } // de-alias the real model name and get a real one realModelName, found := pm.config.RealModelName(modelName) if !found { return nil, fmt.Errorf("could not find modelID for %s", requestedModel) } // check if model is part of the profile if profileName != "" { found := false for _, item := range pm.config.Profiles[profileName] { if item == realModelName { found = true break } } if !found { return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName) } } // exit early when already running, otherwise stop everything and swap requestedProcessKey := 
ProcessKeyName(profileName, realModelName) if process, found := pm.currentProcesses[requestedProcessKey]; found { pm.proxyLogger.Debugf("No-swap, using existing process for model [%s]", requestedModel) return process, nil } // stop all running models pm.proxyLogger.Infof("Swapping model to [%s]", requestedModel) pm.stopProcesses() if profileName == "" { modelConfig, modelID, found := pm.config.FindConfig(realModelName) if !found { return nil, fmt.Errorf("could not find configuration for %s", realModelName) } process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger) processKey := ProcessKeyName(profileName, modelID) pm.currentProcesses[processKey] = process } else { for _, modelName := range pm.config.Profiles[profileName] { if realModelName, found := pm.config.RealModelName(modelName); found { modelConfig, modelID, found := pm.config.FindConfig(realModelName) if !found { return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName) } process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.upstreamLogger, pm.proxyLogger) processKey := ProcessKeyName(profileName, modelID) pm.currentProcesses[processKey] = process } } } // requestedProcessKey should exist due to swap return pm.currentProcesses[requestedProcessKey], nil } func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { requestedModel := c.Param("model_id") if requestedModel == "" { pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") return } if process, err := pm.swapModel(requestedModel); err != nil { pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) } else { // rewrite the path c.Request.URL.Path = c.Param("upstreamPath") process.ProxyRequest(c.Writer, c.Request) } } func (pm *ProxyManager) upstreamIndex(c *gin.Context) { var html strings.Builder html.WriteString("\n