upstream handler support for model names with forward slash (#298)
The upstream handler would break on model IDs that contained a forward slash. Model IDs like "aaa/bbb" called at upstream/aaa/bbb would result in an error. This commit adds support for model IDs with a forward slash by iteratively searching the path for a match. Fixes: #229
This commit is contained in:
@@ -227,7 +227,7 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
pm.ginEngine.GET("/upstream", func(c *gin.Context) {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
@@ -393,24 +393,52 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
||||
requestedModel := c.Param("model_id")
|
||||
upstreamPath := c.Param("upstreamPath")
|
||||
|
||||
if requestedModel == "" {
|
||||
// split the upstream path by / and search for the model name
|
||||
parts := strings.Split(strings.TrimSpace(upstreamPath), "/")
|
||||
if len(parts) == 0 {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
modelFound := false
|
||||
searchModelName := ""
|
||||
var modelName, remainingPath string
|
||||
for i, part := range parts {
|
||||
if parts[i] == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if searchModelName == "" {
|
||||
searchModelName = part
|
||||
} else {
|
||||
searchModelName = searchModelName + "/" + parts[i]
|
||||
}
|
||||
|
||||
if real, ok := pm.config.RealModelName(searchModelName); ok {
|
||||
modelName = real
|
||||
remainingPath = "/" + strings.Join(parts[i+1:], "/")
|
||||
modelFound = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !modelFound {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "model id required in path")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(modelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
// rewrite the path
|
||||
c.Request.URL.Path = c.Param("upstreamPath")
|
||||
c.Request.URL.Path = remainingPath
|
||||
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user