Upstream handler: support model names containing a forward slash (#298)
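The upstream handler broke on model IDs that contain a forward slash. A request to /upstream/aaa/bbb for the model "aaa/bbb" resulted in an error, because the route captured only the first path segment as the model ID. This commit adds support for model IDs with a forward slash by extending the candidate ID one path segment at a time until it matches a configured model; the remainder of the path is then forwarded upstream. Fixes: #229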
This commit is contained in:
@@ -227,7 +227,7 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||||
c.Redirect(http.StatusFound, "/ui/models")
|
c.Redirect(http.StatusFound, "/ui/models")
|
||||||
})
|
})
|
||||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||||
|
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||||
@@ -393,24 +393,52 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||||
requestedModel := c.Param("model_id")
|
upstreamPath := c.Param("upstreamPath")
|
||||||
|
|
||||||
if requestedModel == "" {
|
// split the upstream path by / and search for the model name
|
||||||
|
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
||||||
|
if len(parts) == 0 {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
modelFound := false
|
||||||
|
searchModelName := ""
|
||||||
|
var modelName, remainingPath string
|
||||||
|
for i, part := range parts {
|
||||||
|
if parts[i] == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if searchModelName == "" {
|
||||||
|
searchModelName = part
|
||||||
|
} else {
|
||||||
|
searchModelName = searchModelName + "/" + parts[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||||
|
modelName = real
|
||||||
|
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||||
|
modelFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !modelFound {
|
||||||
|
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
c.Request.URL.Path = remainingPath
|
||||||
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user