From b3d331da0d9750defa8fac56008cb6181df2e953 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Sun, 9 Mar 2025 12:41:52 -0700 Subject: [PATCH] Properly strip profile name slug from models fixes (#62) The profile slug in a model name, `profile:model`, is specific to llama-swap. This strips `profile:` out of the model name request so upstreams that expect just `model` work and do not require knowing about the profile slug. --- go.mod | 4 +++ go.sum | 10 ++++++ misc/simple-responder/simple-responder.go | 21 ++++++++++++ proxy/proxymanager.go | 40 +++++++++++++++-------- proxy/proxymanager_test.go | 23 +++++++++++++ 5 files changed, 84 insertions(+), 14 deletions(-) diff --git a/go.mod b/go.mod index 1d88233..59cd5d7 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,10 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect diff --git a/go.sum b/go.sum index bcd2b87..5091543 100644 --- a/go.sum +++ b/go.sum @@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= diff --git a/misc/simple-responder/simple-responder.go b/misc/simple-responder/simple-responder.go index 1108edd..57ff632 100644 --- a/misc/simple-responder/simple-responder.go +++ b/misc/simple-responder/simple-responder.go @@ -12,12 +12,14 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" ) func main() { gin.SetMode(gin.TestMode) // Define a command-line flag for the port port := flag.String("port", "8080", "port to listen on") + expectedModel := flag.String("model", "TheExpectedModel", "model name to expect") // Define a command-line flag for the response message responseMessage := flag.String("respond", "hi", "message to respond with") @@ -41,6 +43,25 @@ func main() { c.String(200, *responseMessage) }) + // for issue #62 to check model name strips profile slug + // has to be one of the openAI API endpoints that llama-swap proxies + // curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}' + r.POST("/v1/audio/speech", func(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"}) + return + } + defer c.Request.Body.Close() + modelName := gjson.GetBytes(body, "model").String() + if modelName != *expectedModel { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)}) + return + } else { + c.JSON(http.StatusOK, gin.H{"message": "ok"}) + } + }) + r.POST("/v1/completions", func(c *gin.Context) { c.Header("Content-Type", "text/plain") c.String(200, *responseMessage) diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 602fe41..2b2c49e 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -13,6 +13,8 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) const ( @@ -224,11 +226,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) { defer pm.Unlock() // Check if requestedModel contains a PROFILE_SPLIT_CHAR - profileName, modelName := "", requestedModel - if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 { - profileName = requestedModel[:idx] - modelName = requestedModel[idx+1:] - } + profileName, modelName := splitRequestedModel(requestedModel) if profileName != "" { if _, found := pm.config.Profiles[profileName]; !found { @@ -344,21 +342,26 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { return } - var requestBody map[string]interface{} - if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error())) - return - } - model, ok := requestBody["model"].(string) - if !ok { + requestedModel := gjson.GetBytes(bodyBytes, "model").String() + if requestedModel == "" { pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") - return } - if process, err := pm.swapModel(model); err != nil { + if process, err := pm.swapModel(requestedModel); err != nil { pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error())) return } else { + + // strip + profileName, modelName := splitRequestedModel(requestedModel) + if profileName != "" { + bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName) + if err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error())) + return + } + } + c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // dechunk it as we already have all the body bytes see issue #11 @@ -387,3 +390,12 @@ func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { func ProcessKeyName(groupName, modelName string) string { return groupName + PROFILE_SPLIT_CHAR + modelName } + +func splitRequestedModel(requestedModel string) (string, string) { + profileName, modelName := "", requestedModel + if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 { + profileName = requestedModel[:idx] + modelName = requestedModel[idx+1:] + } + return profileName, modelName +} diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index f167443..7388242 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -326,3 +326,26 @@ func TestProxyManager_Unload(t *testing.T) { assert.Equal(t, w.Body.String(), "OK") assert.Len(t, proxy.currentProcesses, 0) } + +// issue 62, strip profile slug from model name +func TestProxyManager_StripProfileSlug(t *testing.T) { + config := &Config{ + HealthCheckTimeout: 15, + Profiles: map[string][]string{ + "test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go + }, + Models: map[string]ModelConfig{ + "TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"), + }, + } + + proxy := New(config) + defer proxy.StopProcesses() + + reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel") + req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "ok") +}