Properly strip profile name slug from models fixes (#62)

The profile slug in a model name, `profile:model`, is specific to
llama-swap. This strips `profile:` out of the model name request so
upstreams that expect just `model` work and do not require knowing about
the profile slug.
This commit is contained in:
Benson Wong
2025-03-09 12:41:52 -07:00
committed by GitHub
parent 62275e078d
commit b3d331da0d
5 changed files with 84 additions and 14 deletions

4
go.mod
View File

@@ -29,6 +29,10 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect

10
go.sum
View File

@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=

View File

@@ -12,12 +12,14 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
func main() {
gin.SetMode(gin.TestMode)
// Define a command-line flag for the port
port := flag.String("port", "8080", "port to listen on")
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
// Define a command-line flag for the response message
responseMessage := flag.String("respond", "hi", "message to respond with")
@@ -41,6 +43,25 @@ func main() {
c.String(200, *responseMessage)
})
// for issue #62 to check model name strips profile slug
// has to be one of the openAI API endpoints that llama-swap proxies
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
r.POST("/v1/audio/speech", func(c *gin.Context) {
body, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
return
}
defer c.Request.Body.Close()
modelName := gjson.GetBytes(body, "model").String()
if modelName != *expectedModel {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
return
} else {
c.JSON(http.StatusOK, gin.H{"message": "ok"})
}
})
r.POST("/v1/completions", func(c *gin.Context) {
c.Header("Content-Type", "text/plain")
c.String(200, *responseMessage)

View File

@@ -13,6 +13,8 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
@@ -224,11 +226,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
defer pm.Unlock()
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
if _, found := pm.config.Profiles[profileName]; !found {
@@ -344,21 +342,26 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
var requestBody map[string]interface{}
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
return
}
model, ok := requestBody["model"].(string)
if !ok {
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
if process, err := pm.swapModel(model); err != nil {
if process, err := pm.swapModel(requestedModel); err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
} else {
// strip
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
@@ -387,3 +390,12 @@ func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
func ProcessKeyName(groupName, modelName string) string {
return groupName + PROFILE_SPLIT_CHAR + modelName
}
func splitRequestedModel(requestedModel string) (string, string) {
profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:]
}
return profileName, modelName
}

View File

@@ -326,3 +326,26 @@ func TestProxyManager_Unload(t *testing.T) {
assert.Equal(t, w.Body.String(), "OK")
assert.Len(t, proxy.currentProcesses, 0)
}
// issue 62, strip profile slug from model name
func TestProxyManager_StripProfileSlug(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
},
Models: map[string]ModelConfig{
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
},
}
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), "ok")
}