Add support for sending a custom model name to upstream (#69) (#71)

* add test for splitRequestedModel()
* Add `useModelName` parameter to model configuration
* add docs to README
This commit is contained in:
Benson Wong
2025-03-14 21:07:52 -07:00
committed by GitHub
parent 671c1a5a7b
commit 5c97299e7b
4 changed files with 146 additions and 14 deletions

View File

@@ -117,6 +117,13 @@ models:
ghcr.io/ggerganov/llama.cpp:server
--model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
# `useModelName` will send a specific model name to the upstream server
# overriding whatever was set in the request
"qwq":
proxy: http://127.0.0.1:11434
cmd: my-server
useModelName: "qwen:qwq"
# profiles make it easy to managing multi model (and gpu) configurations.
#
# Tips:

View File

@@ -17,6 +17,7 @@ type ModelConfig struct {
CheckEndpoint string `yaml:"checkEndpoint"`
UnloadAfter int `yaml:"ttl"`
Unlisted bool `yaml:"unlisted"`
UseModelName string `yaml:"useModelName"`
}
func (m *ModelConfig) SanitizedCommand() ([]string, error) {

View File

@@ -96,7 +96,7 @@ func New(config *Config) *ProxyManager {
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler)
pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
@@ -351,12 +351,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
}
if process, err := pm.swapModel(requestedModel); err != nil {
process, err := pm.swapModel(requestedModel)
if err != nil {
pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
return
} else {
}
// strip
// issue #69 allow custom model names to be sent to upstream
if process.config.UseModelName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
return
}
} else {
profileName, modelName := splitRequestedModel(requestedModel)
if profileName != "" {
bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
@@ -366,6 +375,8 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
}
}
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
// dechunk it as we already have all the body bytes see issue #11
@@ -373,10 +384,10 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
process.ProxyRequest(c.Writer, c.Request)
}
}
func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
@@ -410,9 +421,13 @@ func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
for _, value := range values {
fieldValue := value
// If this is the model field and we have a profile, use just the model name
if key == "model" && profileName != "" {
if key == "model" {
if process.config.UseModelName != "" {
fieldValue = process.config.UseModelName
} else if profileName != "" {
fieldValue = modelName
}
}
field, err := multipartWriter.CreateFormField(key)
if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, "error recreating form field")

View File

@@ -532,3 +532,112 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
})
}
}
func TestProxyManager_SplitRequestedModel(t *testing.T) {
tests := []struct {
name string
requestedModel string
expectedProfile string
expectedModel string
}{
{"no profile", "gpt-4", "", "gpt-4"},
{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
{"only profile", "profile1:", "profile1", ""},
{"empty model", ":gpt-4", "", "gpt-4"},
{"empty profile", ":", "", ""},
{"no split char", "gpt-4", "", "gpt-4"},
{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
profileName, modelName := splitRequestedModel(tt.requestedModel)
if profileName != tt.expectedProfile {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
if modelName != tt.expectedModel {
t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
}
})
}
}
// Test useModelName in configuration sends overrides what is sent to upstream
func TestProxyManager_UseModelName(t *testing.T) {
upstreamModelName := "upstreamModel"
modelConfig := getTestSimpleResponderConfig(upstreamModelName)
modelConfig.UseModelName = upstreamModelName
config := &Config{
HealthCheckTimeout: 15,
Profiles: map[string][]string{
"test": {"model1"},
},
Models: map[string]ModelConfig{
"model1": modelConfig,
},
}
proxy := New(config)
defer proxy.StopProcesses()
tests := []struct {
description string
requestedModel string
}{
{"useModelName over rides requested model", "model1"},
{"useModelName over rides requested profile:model", "test:model1"},
}
for _, tt := range tests {
t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), upstreamModelName)
})
}
for _, tt := range tests {
t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
// Create a buffer with multipart form data
var b bytes.Buffer
w := multipart.NewWriter(&b)
// Add the model field
fw, err := w.CreateFormField("model")
assert.NoError(t, err)
_, err = fw.Write([]byte(tt.requestedModel))
assert.NoError(t, err)
// Add a file field
fw, err = w.CreateFormFile("file", "test.mp3")
assert.NoError(t, err)
_, err = fw.Write([]byte("test"))
assert.NoError(t, err)
w.Close()
// Create the request with the multipart form data
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
req.Header.Set("Content-Type", w.FormDataContentType())
rec := httptest.NewRecorder()
proxy.HandlerFunc(rec, req)
// Verify the response
assert.Equal(t, http.StatusOK, rec.Code)
var response map[string]string
err = json.Unmarshal(rec.Body.Bytes(), &response)
assert.NoError(t, err)
assert.Equal(t, upstreamModelName, response["model"])
})
}
}