diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go
index 67da376..9383f06 100644
--- a/proxy/proxymanager.go
+++ b/proxy/proxymanager.go
@@ -227,7 +227,7 @@ func (pm *ProxyManager) setupGinEngine() {
 	pm.ginEngine.GET("/upstream", func(c *gin.Context) {
 		c.Redirect(http.StatusFound, "/ui/models")
 	})
-	pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
+	pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
 
 	pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
 	pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
@@ -393,24 +393,57 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
 }
 
+// proxyToUpstream handles /upstream/*upstreamPath. It finds the model named at
+// the start of the wildcard path, swaps to that model's process group, and hands
+// the request to ProxyRequest with the model prefix stripped from the URL path.
+// Paths that name no model known to the config get a 400; a failed swap gets a 500.
 func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
-	requestedModel := c.Param("model_id")
+	upstreamPath := c.Param("upstreamPath")
 
-	if requestedModel == "" {
+	// The model name may span several path segments, so grow a candidate one
+	// segment at a time and stop at the first prefix the config recognizes.
+	// Empty segments (leading, trailing or doubled slashes) are skipped.
+	// strings.Split always yields at least one element here, so a path with
+	// no usable segment is rejected by the modelFound check below.
+	parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
+
+	modelFound := false
+	searchModelName := ""
+	var modelName, remainingPath string
+	for i, part := range parts {
+		if part == "" {
+			continue
+		}
+
+		if searchModelName == "" {
+			searchModelName = part
+		} else {
+			searchModelName = searchModelName + "/" + part
+		}
+
+		if real, ok := pm.config.RealModelName(searchModelName); ok {
+			modelName = real
+			remainingPath = "/" + strings.Join(parts[i+1:], "/")
+			modelFound = true
+			break
+		}
+	}
+
+	if !modelFound {
 		pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
 		return
 	}
 
-	processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
+	processGroup, realModelName, err := pm.swapProcessGroup(modelName)
 	if err != nil {
 		pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
 		return
 	}
 
 	// rewrite the path
-	c.Request.URL.Path = c.Param("upstreamPath")
+	c.Request.URL.Path = remainingPath
 	processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
 }
 
 func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
 	bodyBytes, err := io.ReadAll(c.Request.Body)
 	if err != nil {