Add /upstream endpoint (#30)

* remove catch-all route to upstream proxy (it was broken anyway)
* add /upstream/:model_id to swap and route to upstream path
* add /upstream HTML endpoint and unlisted option
* add /upstream endpoint to show a list of available models
* add `unlisted` configuration option to omit a model from /v1/models and /upstream lists
* add favicon.ico
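A rough sketch of how a client might exercise the new endpoints; the listen address, the model id `qwen-7b`, and the `/health` upstream path are invented for illustration and are not part of this commit:

```go
package main

import (
	"fmt"
	"io"
	"log"
	"net/http"
)

func main() {
	base := "http://localhost:8080" // assumed llama-swap listen address; adjust to your config

	// GET /upstream returns the HTML index of available (non-unlisted) models.
	resp, err := http.Get(base + "/upstream")
	if err != nil {
		log.Fatal(err)
	}
	defer resp.Body.Close()
	page, _ := io.ReadAll(resp.Body)
	fmt.Println(string(page))

	// Requests under /upstream/<model_id>/... swap that model in and are proxied
	// to its upstream server; "qwen-7b" and "/health" are invented examples.
	resp2, err := http.Get(base + "/upstream/qwen-7b/health")
	if err != nil {
		log.Fatal(err)
	}
	defer resp2.Body.Close()
	fmt.Println(resp2.Status)
}
```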
Author: Benson Wong (committed via GitHub)
Date:   2024-12-17 14:37:44 -08:00
Parent: 7183f6b43d
Commit: 891f6a5b5a
7 changed files with 78 additions and 22 deletions

@@ -2,10 +2,12 @@ package proxy
 import (
     "bytes"
+    "embed"
     "encoding/json"
     "fmt"
     "io"
     "net/http"
+    "sort"
     "strconv"
     "strings"
     "sync"
@@ -18,6 +20,15 @@ const (
     PROFILE_SPLIT_CHAR = ":"
 )
+
+//go:embed html/favicon.ico
+var faviconData []byte
+
+//go:embed html/logs.html
+var logsHTML []byte
+
+// make sure embed is kept there by the IDE auto-package importer
+var _ = embed.FS{}

 type ProxyManager struct {
     sync.Mutex
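An aside on the `var _ = embed.FS{}` line above: `//go:embed` into a `[]byte` requires the `embed` package to be imported, but nothing else in the file references it, so tooling such as goimports would otherwise strip the import. The more common idiom is a blank import; a minimal sketch of that alternative (assuming the same `html/favicon.ico` asset), not what this commit does:

```go
package proxy

import _ "embed" // blank import: //go:embed works without referencing embed.FS

//go:embed html/favicon.ico
var faviconData []byte
```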
@@ -48,7 +59,12 @@ func New(config *Config) *ProxyManager {
pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler)
pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE)
pm.ginEngine.NoRoute(pm.proxyNoRouteHandler)
pm.ginEngine.GET("/upstream", pm.upstreamIndex)
pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/favicon.ico", func(c *gin.Context) {
c.Data(http.StatusOK, "image/x-icon", faviconData)
})
// Disable console color for testing
gin.DisableConsoleColor()
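A note on the wildcard route registered above: in gin, `:model_id` matches a single path segment and `*upstreamPath` captures the remainder of the path including its leading slash, which is why the handler can assign it directly to `c.Request.URL.Path`. A standalone sketch of the binding (not the project's code; the model id in the comment is invented):

```go
package main

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

func main() {
	r := gin.Default()

	// For GET /upstream/qwen-7b/v1/models this responds with:
	//   model_id=qwen-7b upstreamPath=/v1/models
	r.Any("/upstream/:model_id/*upstreamPath", func(c *gin.Context) {
		c.String(http.StatusOK, "model_id=%s upstreamPath=%s",
			c.Param("model_id"), c.Param("upstreamPath"))
	})

	r.Run(":8080")
}
```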
@@ -86,7 +102,11 @@ func (pm *ProxyManager) stopProcesses() {
 func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
     data := []interface{}{}
-    for id := range pm.config.Models {
+    for id, modelConfig := range pm.config.Models {
+        if modelConfig.Unlisted {
+            continue
+        }
         data = append(data, map[string]interface{}{
             "id": id,
             "object": "model",
@@ -113,7 +133,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
     pm.Lock()
     defer pm.Unlock()
-    // Check if requestedModel contains a /
+    // Check if requestedModel contains a PROFILE_SPLIT_CHAR
     profileName, modelName := "", requestedModel
     if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
         profileName = requestedModel[:idx]
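The comment change above reflects that a requested model can carry a profile prefix of the form `profile:model`, split on `PROFILE_SPLIT_CHAR`. A tiny standalone sketch of that split; the profile and model names are invented:

```go
package main

import (
	"fmt"
	"strings"
)

const PROFILE_SPLIT_CHAR = ":"

func main() {
	requestedModel := "coding:qwen-7b" // invented profile:model pair

	profileName, modelName := "", requestedModel
	if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
		profileName = requestedModel[:idx]
		modelName = requestedModel[idx+1:]
	}

	fmt.Println(profileName, modelName) // coding qwen-7b
}
```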
@@ -170,6 +190,48 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
     return pm.currentProcesses[requestedProcessKey], nil
 }
+
+func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
+    requestedModel := c.Param("model_id")
+    if requestedModel == "" {
+        c.AbortWithError(http.StatusBadRequest, fmt.Errorf("model id required in path"))
+        return
+    }
+    if process, err := pm.swapModel(requestedModel); err != nil {
+        c.AbortWithError(http.StatusNotFound, fmt.Errorf("unable to swap to model, %s", err.Error()))
+    } else {
+        // rewrite the path
+        c.Request.URL.Path = c.Param("upstreamPath")
+        process.ProxyRequest(c.Writer, c.Request)
+    }
+}
+
+func (pm *ProxyManager) upstreamIndex(c *gin.Context) {
+    var html strings.Builder
+    html.WriteString("<!doctype HTML>\n<html><body><h1>Available Models</h1><ul>")
+
+    // Extract keys and sort them
+    var modelIDs []string
+    for modelID, modelConfig := range pm.config.Models {
+        if modelConfig.Unlisted {
+            continue
+        }
+        modelIDs = append(modelIDs, modelID)
+    }
+    sort.Strings(modelIDs)
+
+    // Iterate over sorted keys
+    for _, modelID := range modelIDs {
+        html.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", modelID, modelID))
+    }
+    html.WriteString("</ul></body></html>")
+    c.Header("Content-Type", "text/html")
+    c.String(http.StatusOK, html.String())
+}
+
 func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
     bodyBytes, err := io.ReadAll(c.Request.Body)
     if err != nil {
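A small design note on `upstreamIndex` above: model IDs are interpolated into the HTML with `fmt.Sprintf` and are not escaped. Since the IDs come from the operator's own config this is mostly cosmetic, but an escaped variant is a one-line change; a sketch using `html.EscapeString` (not what the commit does; the ids are invented):

```go
package main

import (
	"fmt"
	"html"
	"strings"
)

func main() {
	var b strings.Builder
	for _, modelID := range []string{"qwen-7b", "weird<id>"} { // invented ids
		safe := html.EscapeString(modelID)
		b.WriteString(fmt.Sprintf("<li><a href=\"/upstream/%s\">%s</a></li>", safe, safe))
	}
	fmt.Println(b.String())
}
```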
@@ -201,16 +263,6 @@ func (pm *ProxyManager) proxyChatRequestHandler(c *gin.Context) {
     }
 }

-func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
-    // since maps are unordered, just use the first available process if one exists
-    for _, process := range pm.currentProcesses {
-        process.ProxyRequest(c.Writer, c.Request)
-        return
-    }
-    c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
-}
-
 func ProcessKeyName(groupName, modelName string) string {
     return groupName + PROFILE_SPLIT_CHAR + modelName
 }
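For reference, `ProcessKeyName` builds the key used to look up running processes. A quick standalone sketch of its output; the profile and model names are invented:

```go
package main

import "fmt"

const PROFILE_SPLIT_CHAR = ":"

func ProcessKeyName(groupName, modelName string) string {
	return groupName + PROFILE_SPLIT_CHAR + modelName
}

func main() {
	fmt.Println(ProcessKeyName("coding", "qwen-7b")) // coding:qwen-7b
	fmt.Println(ProcessKeyName("", "qwen-7b"))       // :qwen-7b when no profile is given
}
```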