// Package proxy implements the llama-swap HTTP proxy: it routes
// OpenAI-compatible API requests to upstream model processes, swapping
// process groups on demand based on the requested model.
package proxy

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/mostlygeek/llama-swap/event"
	"github.com/mostlygeek/llama-swap/proxy/config"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

const (
	// PROFILE_SPLIT_CHAR separates a profile prefix from a model name in a
	// model identifier (e.g. "profile:model"). Not referenced in this chunk;
	// presumably used by callers elsewhere — confirm at call sites.
	PROFILE_SPLIT_CHAR = ":"
)

// ProxyManager owns the gin engine, loggers, metrics and all upstream
// process groups. The embedded Mutex guards process lifecycle operations
// (StopProcesses / Shutdown); pass *ProxyManager, never copy it.
type ProxyManager struct {
	sync.Mutex

	config    config.Config
	ginEngine *gin.Engine

	// logging
	proxyLogger    *LogMonitor // proxy-level request/diagnostic logging
	upstreamLogger *LogMonitor // output from upstream model processes
	muxLogger      *LogMonitor // combined stdout writer both loggers feed into

	metricsMonitor *MetricsMonitor

	// process groups keyed by group ID from the configuration
	processGroups map[string]*ProcessGroup

	// shutdown signaling
	shutdownCtx    context.Context
	shutdownCancel context.CancelFunc
}

// New builds a ProxyManager from config: it wires the loggers (honoring
// config.LogLevel, defaulting to info), creates one ProcessGroup per
// configured group, sets up all HTTP routes, and kicks off any configured
// startup preload hooks in the background. It never blocks on preloading.
func New(config config.Config) *ProxyManager {
	// set up loggers
	stdoutLogger := NewLogMonitorWriter(os.Stdout)
	upstreamLogger := NewLogMonitorWriter(stdoutLogger)
	proxyLogger := NewLogMonitorWriter(stdoutLogger)

	if config.LogRequests {
		proxyLogger.Warn("LogRequests configuration is deprecated. Use logLevel instead.")
	}

	// map the configured log level onto both loggers; anything unrecognized
	// (including empty) falls back to info
	switch strings.ToLower(strings.TrimSpace(config.LogLevel)) {
	case "debug":
		proxyLogger.SetLogLevel(LevelDebug)
		upstreamLogger.SetLogLevel(LevelDebug)
	case "info":
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	case "warn":
		proxyLogger.SetLogLevel(LevelWarn)
		upstreamLogger.SetLogLevel(LevelWarn)
	case "error":
		proxyLogger.SetLogLevel(LevelError)
		upstreamLogger.SetLogLevel(LevelError)
	default:
		proxyLogger.SetLogLevel(LevelInfo)
		upstreamLogger.SetLogLevel(LevelInfo)
	}

	shutdownCtx, shutdownCancel := context.WithCancel(context.Background())

	pm := &ProxyManager{
		config:         config,
		ginEngine:      gin.New(),
		proxyLogger:    proxyLogger,
		muxLogger:      stdoutLogger,
		upstreamLogger: upstreamLogger,
		metricsMonitor: NewMetricsMonitor(&config),

		processGroups: make(map[string]*ProcessGroup),

		shutdownCtx:    shutdownCtx,
		shutdownCancel: shutdownCancel,
	}

	// create the process groups
	for groupID := range config.Groups {
		processGroup := NewProcessGroup(groupID, config, proxyLogger, upstreamLogger)
		pm.processGroups[groupID] = processGroup
	}

	pm.setupGinEngine()

	// run any startup hooks
	if len(config.Hooks.OnStartup.Preload) > 0 {
		// do it in the background, don't block startup -- not sure if good idea yet
		go func() {
			discardWriter := &DiscardWriter{}
			for _, realModelName := range config.Hooks.OnStartup.Preload {
				proxyLogger.Infof("Preloading model: %s", realModelName)
				processGroup, _, err := pm.swapProcessGroup(realModelName)
				if err != nil {
					event.Emit(ModelPreloadedEvent{
						ModelName: realModelName,
						Success:   false,
					})
					proxyLogger.Errorf("Failed to preload model %s: %v", realModelName, err)
					continue
				} else {
					// a throwaway GET forces the process to start; the
					// response is discarded
					req, _ := http.NewRequest("GET", "/", nil)
					processGroup.ProxyRequest(realModelName, discardWriter, req)
					event.Emit(ModelPreloadedEvent{
						ModelName: realModelName,
						Success:   true,
					})
				}
			}
		}()
	}

	return pm
}

// setupGinEngine registers middleware (request logging, permissive CORS for
// OPTIONS) and every HTTP route: the OpenAI-compatible endpoints, the
// upstream passthrough, the log streaming endpoints, and the embedded
// React UI.
func (pm *ProxyManager) setupGinEngine() {
	pm.ginEngine.Use(func(c *gin.Context) {
		// don't log the Wake on Lan proxy health check
		if c.Request.URL.Path == "/wol-health" {
			c.Next()
			return
		}

		// Start timer
		start := time.Now()

		// capture these because /upstream/:model rewrites them in c.Next()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Process request
		c.Next()

		// Stop timer
		duration := time.Since(start)

		statusCode := c.Writer.Status()
		bodySize := c.Writer.Size()

		pm.proxyLogger.Infof("Request %s \"%s %s %s\" %d %d \"%s\" %v",
			clientIP,
			method,
			path,
			c.Request.Proto,
			statusCode,
			bodySize,
			c.Request.UserAgent(),
			duration,
		)
	})

	// see: issue: #81, #77 and #42 for CORS issues
	// respond with permissive OPTIONS for any endpoint
	pm.ginEngine.Use(func(c *gin.Context) {
		if c.Request.Method == "OPTIONS" {
			c.Header("Access-Control-Allow-Origin", "*")
			c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")

			// allow whatever the client requested by default
			if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" {
				sanitized := SanitizeAccessControlRequestHeaderValues(headers)
				c.Header("Access-Control-Allow-Headers", sanitized)
			} else {
				c.Header(
					"Access-Control-Allow-Headers",
					"Content-Type, Authorization, Accept, X-Requested-With",
				)
			}
			c.Header("Access-Control-Max-Age", "86400")
			c.AbortWithStatus(http.StatusNoContent)
			return
		}
		c.Next()
	})

	mm := MetricsMiddleware(pm)

	// Set up routes using the Gin engine
	pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
	// Support legacy /v1/completions api, see issue #12
	pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)

	// Support embeddings and reranking
	pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)

	// llama-server's /reranking endpoint + aliases
	pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
	pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)

	// llama-server's /infill endpoint for code infilling
	pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
	// llama-server's /completion endpoint
	pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)

	// Support audio/speech endpoint
	// NOTE(review): no metrics middleware on the audio endpoints — looks
	// intentional but confirm
	pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
	pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)

	pm.ginEngine.GET("/v1/models", pm.listModelsHandler)

	// in proxymanager_loghandlers.go
	pm.ginEngine.GET("/logs", pm.sendLogsHandlers)
	pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
	pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler)

	/**
	 * User Interface Endpoints
	 */
	pm.ginEngine.GET("/", func(c *gin.Context) {
		c.Redirect(http.StatusFound, "/ui")
	})
	pm.ginEngine.GET("/upstream", func(c *gin.Context) {
		c.Redirect(http.StatusFound, "/ui/models")
	})
	pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
	pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
	pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
	pm.ginEngine.GET("/health", func(c *gin.Context) {
		c.String(http.StatusOK, "OK")
	})

	// see cmd/wol-proxy/wol-proxy.go, not logged
	pm.ginEngine.GET("/wol-health", func(c *gin.Context) {
		c.String(http.StatusOK, "OK")
	})

	pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
		if data, err := reactStaticFS.ReadFile("ui_dist/favicon.ico"); err == nil {
			c.Data(http.StatusOK, "image/x-icon", data)
		} else {
			c.String(http.StatusInternalServerError, err.Error())
		}
	})

	reactFS, err := GetReactFS()
	if err != nil {
		pm.proxyLogger.Errorf("Failed to load React filesystem: %v", err)
	} else {
		// serve files that exist under /ui/*
		pm.ginEngine.StaticFS("/ui", reactFS)

		// server SPA for UI under /ui/*
		pm.ginEngine.NoRoute(func(c *gin.Context) {
			if !strings.HasPrefix(c.Request.URL.Path, "/ui") {
				c.AbortWithStatus(http.StatusNotFound)
				return
			}

			// SPA fallback: any /ui/* path that isn't a real file serves
			// index.html and lets the client-side router take over
			file, err := reactFS.Open("index.html")
			if err != nil {
				c.String(http.StatusInternalServerError, err.Error())
				return
			}
			defer file.Close()
			http.ServeContent(c.Writer, c.Request, "index.html", time.Now(), file)
		})
	}

	// see: proxymanager_api.go
	// add API handler functions
	addApiHandlers(pm)

	// Disable console color for testing
	gin.DisableConsoleColor()
}

// ServeHTTP implements http.Handler interface
func (pm *ProxyManager) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	pm.ginEngine.ServeHTTP(w, r)
}

// StopProcesses acquires a lock and stops all running upstream processes.
// This is the public method safe for concurrent calls.
// Unlike Shutdown, this method only stops the processes but doesn't perform
// a complete shutdown, allowing for process replacement without full termination.
func (pm *ProxyManager) StopProcesses(strategy StopStrategy) { pm.Lock() defer pm.Unlock() // stop Processes in parallel var wg sync.WaitGroup for _, processGroup := range pm.processGroups { wg.Add(1) go func(processGroup *ProcessGroup) { defer wg.Done() processGroup.StopProcesses(strategy) }(processGroup) } wg.Wait() } // Shutdown stops all processes managed by this ProxyManager func (pm *ProxyManager) Shutdown() { pm.Lock() defer pm.Unlock() pm.proxyLogger.Debug("Shutdown() called in proxy manager") var wg sync.WaitGroup // Send shutdown signal to all process in groups for _, processGroup := range pm.processGroups { wg.Add(1) go func(processGroup *ProcessGroup) { defer wg.Done() processGroup.Shutdown() }(processGroup) } wg.Wait() pm.shutdownCancel() } func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) { // de-alias the real model name and get a real one realModelName, found := pm.config.RealModelName(requestedModel) if !found { return nil, realModelName, fmt.Errorf("could not find real modelID for %s", requestedModel) } processGroup := pm.findGroupByModelName(realModelName) if processGroup == nil { return nil, realModelName, fmt.Errorf("could not find process group for model %s", requestedModel) } if processGroup.exclusive { pm.proxyLogger.Debugf("Exclusive mode for group %s, stopping other process groups", processGroup.id) for groupId, otherGroup := range pm.processGroups { if groupId != processGroup.id && !otherGroup.persistent { otherGroup.StopProcesses(StopWaitForInflightRequest) } } } return processGroup, realModelName, nil } func (pm *ProxyManager) listModelsHandler(c *gin.Context) { data := make([]gin.H, 0, len(pm.config.Models)) createdTime := time.Now().Unix() for id, modelConfig := range pm.config.Models { if modelConfig.Unlisted { continue } record := gin.H{ "id": id, "object": "model", "created": createdTime, "owned_by": "llama-swap", } if name := strings.TrimSpace(modelConfig.Name); name != "" { 
record["name"] = name } if desc := strings.TrimSpace(modelConfig.Description); desc != "" { record["description"] = desc } // Add metadata if present if len(modelConfig.Metadata) > 0 { record["meta"] = gin.H{ "llamaswap": modelConfig.Metadata, } } data = append(data, record) } // Sort by the "id" key sort.Slice(data, func(i, j int) bool { si, _ := data[i]["id"].(string) sj, _ := data[j]["id"].(string) return si < sj }) // Set CORS headers if origin exists if origin := c.GetHeader("Origin"); origin != "" { c.Header("Access-Control-Allow-Origin", origin) } // Use gin's JSON method which handles content-type and encoding c.JSON(http.StatusOK, gin.H{ "object": "list", "data": data, }) } func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { upstreamPath := c.Param("upstreamPath") // split the upstream path by / and search for the model name parts := strings.Split(strings.TrimSpace(upstreamPath), "/") if len(parts) == 0 { pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") return } modelFound := false searchModelName := "" var modelName, remainingPath string for i, part := range parts { if parts[i] == "" { continue } if searchModelName == "" { searchModelName = part } else { searchModelName = searchModelName + "/" + parts[i] } if real, ok := pm.config.RealModelName(searchModelName); ok { modelName = real remainingPath = "/" + strings.Join(parts[i+1:], "/") modelFound = true // Check if this is exactly a model name with no additional path // and doesn't end with a trailing slash if remainingPath == "/" && !strings.HasSuffix(upstreamPath, "/") { // Build new URL with query parameters preserved newPath := "/upstream/" + searchModelName + "/" if c.Request.URL.RawQuery != "" { newPath += "?" 
+ c.Request.URL.RawQuery } // Use 308 for non-GET/HEAD requests to preserve method if c.Request.Method == http.MethodGet || c.Request.Method == http.MethodHead { c.Redirect(http.StatusMovedPermanently, newPath) } else { c.Redirect(http.StatusPermanentRedirect, newPath) } return } break } } if !modelFound { pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path") return } processGroup, realModelName, err := pm.swapProcessGroup(modelName) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) return } // rewrite the path c.Request.URL.Path = remainingPath processGroup.ProxyRequest(realModelName, c.Writer, c.Request) } func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") return } requestedModel := gjson.GetBytes(bodyBytes, "model").String() if requestedModel == "" { pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") return } realModelName, found := pm.config.RealModelName(requestedModel) if !found { pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel)) return } processGroup, _, err := pm.swapProcessGroup(realModelName) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) return } // issue #69 allow custom model names to be sent to upstream useModelName := pm.config.Models[realModelName].UseModelName if useModelName != "" { bodyBytes, err = sjson.SetBytes(bodyBytes, "model", useModelName) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error rewriting model name in JSON: %s", err.Error())) return } } // issue #174 strip parameters from the JSON body stripParams, err := 
pm.config.Models[realModelName].Filters.SanitizedStripParams() if err != nil { // just log it and continue pm.proxyLogger.Errorf("Error sanitizing strip params string: %s, %s", pm.config.Models[realModelName].Filters.StripParams, err.Error()) } else { for _, param := range stripParams { pm.proxyLogger.Debugf("<%s> stripping param: %s", realModelName, param) bodyBytes, err = sjson.DeleteBytes(bodyBytes, param) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error deleting parameter %s from request", param)) return } } } c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // dechunk it as we already have all the body bytes see issue #11 c.Request.Header.Del("transfer-encoding") c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes))) c.Request.ContentLength = int64(len(bodyBytes)) if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) return } } func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { // Parse multipart form if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) return } // Get model parameter from the form requestedModel := c.Request.FormValue("model") if requestedModel == "" { pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' parameter in form data") return } processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) return } // We need to reconstruct the multipart form in any case since the body is 
consumed // Create a new buffer for the reconstructed request var requestBuffer bytes.Buffer multipartWriter := multipart.NewWriter(&requestBuffer) // Copy all form values for key, values := range c.Request.MultipartForm.Value { for _, value := range values { fieldValue := value // If this is the model field and we have a profile, use just the model name if key == "model" { // # issue #69 allow custom model names to be sent to upstream useModelName := pm.config.Models[realModelName].UseModelName if useModelName != "" { fieldValue = useModelName } else { fieldValue = requestedModel } } field, err := multipartWriter.CreateFormField(key) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field") return } if _, err = field.Write([]byte(fieldValue)); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, "error writing form field") return } } } // Copy all files from the original request for key, fileHeaders := range c.Request.MultipartForm.File { for _, fileHeader := range fileHeaders { formFile, err := multipartWriter.CreateFormFile(key, fileHeader.Filename) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form file") return } file, err := fileHeader.Open() if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, "error opening uploaded file") return } if _, err = io.Copy(formFile, file); err != nil { file.Close() pm.sendErrorResponse(c, http.StatusInternalServerError, "error copying file data") return } file.Close() } } // Close the multipart writer to finalize the form if err := multipartWriter.Close(); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, "error finalizing multipart form") return } // Create a new request with the reconstructed form data modifiedReq, err := http.NewRequestWithContext( c.Request.Context(), c.Request.Method, c.Request.URL.String(), &requestBuffer, ) if err != nil { pm.sendErrorResponse(c, 
http.StatusInternalServerError, "error creating modified request") return } // Copy the headers from the original request modifiedReq.Header = c.Request.Header.Clone() modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) // set the content length of the body modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len())) modifiedReq.ContentLength = int64(requestBuffer.Len()) // Use the modified request for proxying if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) return } } func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, message string) { acceptHeader := c.GetHeader("Accept") if strings.Contains(acceptHeader, "application/json") { c.JSON(statusCode, gin.H{"error": message}) } else { c.String(statusCode, message) } } func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { pm.StopProcesses(StopImmediately) c.String(http.StatusOK, "OK") } func (pm *ProxyManager) listRunningProcessesHandler(context *gin.Context) { context.Header("Content-Type", "application/json") runningProcesses := make([]gin.H, 0) // Default to an empty response. for _, processGroup := range pm.processGroups { for _, process := range processGroup.processes { if process.CurrentState() == StateReady { runningProcesses = append(runningProcesses, gin.H{ "model": process.ID, "state": process.state, }) } } } // Put the results under the `running` key. response := gin.H{ "running": runningProcesses, } context.JSON(http.StatusOK, response) // Always return 200 OK } func (pm *ProxyManager) findGroupByModelName(modelName string) *ProcessGroup { for _, group := range pm.processGroups { if group.HasMember(modelName) { return group } } return nil }