Properly strip profile name slug from models fixes (#62)
The profile slug in a model name, `profile:model`, is specific to llama-swap. This strips `profile:` out of the model name request so upstreams that expect just `model` work and do not require knowing about the profile slug.
This commit is contained in:
4
go.mod
4
go.mod
@@ -29,6 +29,10 @@ require (
|
|||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
|
||||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||||
|
github.com/tidwall/gjson v1.18.0 // indirect
|
||||||
|
github.com/tidwall/match v1.1.1 // indirect
|
||||||
|
github.com/tidwall/pretty v1.2.1 // indirect
|
||||||
|
github.com/tidwall/sjson v1.2.5 // indirect
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||||
golang.org/x/arch v0.8.0 // indirect
|
golang.org/x/arch v0.8.0 // indirect
|
||||||
|
|||||||
10
go.sum
10
go.sum
@@ -57,6 +57,16 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||||
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
|
||||||
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
|
||||||
|
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
|
||||||
|
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
|
||||||
|
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
|
||||||
|
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
|
||||||
|
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
|
||||||
|
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
|
||||||
|
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
|
||||||
|
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||||
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
|
||||||
|
|||||||
@@ -12,12 +12,14 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
// Define a command-line flag for the port
|
// Define a command-line flag for the port
|
||||||
port := flag.String("port", "8080", "port to listen on")
|
port := flag.String("port", "8080", "port to listen on")
|
||||||
|
expectedModel := flag.String("model", "TheExpectedModel", "model name to expect")
|
||||||
|
|
||||||
// Define a command-line flag for the response message
|
// Define a command-line flag for the response message
|
||||||
responseMessage := flag.String("respond", "hi", "message to respond with")
|
responseMessage := flag.String("respond", "hi", "message to respond with")
|
||||||
@@ -41,6 +43,25 @@ func main() {
|
|||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// for issue #62 to check model name strips profile slug
|
||||||
|
// has to be one of the openAI API endpoints that llama-swap proxies
|
||||||
|
// curl http://localhost:8080/v1/audio/speech -d '{"model":"profile:TheExpectedModel"}'
|
||||||
|
r.POST("/v1/audio/speech", func(c *gin.Context) {
|
||||||
|
body, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read request body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer c.Request.Body.Close()
|
||||||
|
modelName := gjson.GetBytes(body, "model").String()
|
||||||
|
if modelName != *expectedModel {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid model: %s, expected: %s", modelName, *expectedModel)})
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, gin.H{"message": "ok"})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "text/plain")
|
||||||
c.String(200, *responseMessage)
|
c.String(200, *responseMessage)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -224,11 +226,7 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
|
|||||||
defer pm.Unlock()
|
defer pm.Unlock()
|
||||||
|
|
||||||
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
// Check if requestedModel contains a PROFILE_SPLIT_CHAR
|
||||||
profileName, modelName := "", requestedModel
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
|
||||||
profileName = requestedModel[:idx]
|
|
||||||
modelName = requestedModel[idx+1:]
|
|
||||||
}
|
|
||||||
|
|
||||||
if profileName != "" {
|
if profileName != "" {
|
||||||
if _, found := pm.config.Profiles[profileName]; !found {
|
if _, found := pm.config.Profiles[profileName]; !found {
|
||||||
@@ -344,21 +342,26 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody map[string]interface{}
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||||
if err := json.Unmarshal(bodyBytes, &requestBody); err != nil {
|
if requestedModel == "" {
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("invalid JSON: %s", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
model, ok := requestBody["model"].(string)
|
|
||||||
if !ok {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if process, err := pm.swapModel(model); err != nil {
|
if process, err := pm.swapModel(requestedModel); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
|
// strip
|
||||||
|
profileName, modelName := splitRequestedModel(requestedModel)
|
||||||
|
if profileName != "" {
|
||||||
|
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
|
||||||
|
if err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
@@ -387,3 +390,12 @@ func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) {
|
|||||||
func ProcessKeyName(groupName, modelName string) string {
|
func ProcessKeyName(groupName, modelName string) string {
|
||||||
return groupName + PROFILE_SPLIT_CHAR + modelName
|
return groupName + PROFILE_SPLIT_CHAR + modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func splitRequestedModel(requestedModel string) (string, string) {
|
||||||
|
profileName, modelName := "", requestedModel
|
||||||
|
if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
|
||||||
|
profileName = requestedModel[:idx]
|
||||||
|
modelName = requestedModel[idx+1:]
|
||||||
|
}
|
||||||
|
return profileName, modelName
|
||||||
|
}
|
||||||
|
|||||||
@@ -326,3 +326,26 @@ func TestProxyManager_Unload(t *testing.T) {
|
|||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
assert.Len(t, proxy.currentProcesses, 0)
|
assert.Len(t, proxy.currentProcesses, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// issue 62, strip profile slug from model name
|
||||||
|
func TestProxyManager_StripProfileSlug(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Profiles: map[string][]string{
|
||||||
|
"test": {"TheExpectedModel"}, // TheExpectedModel is default in simple-responder.go
|
||||||
|
},
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"TheExpectedModel": getTestSimpleResponderConfig("TheExpectedModel"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "test:TheExpectedModel")
|
||||||
|
req := httptest.NewRequest("POST", "/v1/audio/speech", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
assert.Contains(t, w.Body.String(), "ok")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user