diff --git a/README.md b/README.md
index 0418071..ce89476 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,13 @@ models:
       ghcr.io/ggerganov/llama.cpp:server
       --model '/models/Qwen2.5-Coder-0.5B-Instruct-Q4_K_M.gguf'
 
+  # `useModelName` will send a specific model name to the upstream server,
+  # overriding whatever was set in the request
+  "qwq":
+    proxy: http://127.0.0.1:11434
+    cmd: my-server
+    useModelName: "qwen:qwq"
+
 # profiles make it easy to managing multi model (and gpu) configurations.
 #
 # Tips:
diff --git a/proxy/config.go b/proxy/config.go
index 3206ae9..56fedf6 100644
--- a/proxy/config.go
+++ b/proxy/config.go
@@ -17,6 +17,7 @@ type ModelConfig struct {
 	CheckEndpoint string `yaml:"checkEndpoint"`
 	UnloadAfter   int    `yaml:"ttl"`
 	Unlisted      bool   `yaml:"unlisted"`
+	UseModelName  string `yaml:"useModelName"`
 }
 
 func (m *ModelConfig) SanitizedCommand() ([]string, error) {
diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go
index 68eb798..7af29e1 100644
--- a/proxy/proxymanager.go
+++ b/proxy/proxymanager.go
@@ -96,7 +96,7 @@ func New(config *Config) *ProxyManager {
 
 	// Support audio/speech endpoint
 	pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
-	pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIAudioTranscriptionHandler)
+	pm.ginEngine.POST("/v1/audio/transcriptions", pm.proxyOAIPostFormHandler)
 
 	pm.ginEngine.GET("/v1/models", pm.listModelsHandler)
 
@@ -351,12 +351,21 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
 		pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
 	}
 
-	if process, err := pm.swapModel(requestedModel); err != nil {
+	process, err := pm.swapModel(requestedModel)
+
+	if err != nil {
 		pm.sendErrorResponse(c, http.StatusNotFound, fmt.Sprintf("unable to swap to model, %s", err.Error()))
 		return
-	} else {
+	}
 
-		// strip
+	// issue #69: allow custom model names to be sent to upstream
+	if process.config.UseModelName != "" {
+		bodyBytes, err = sjson.SetBytes(bodyBytes, "model", process.config.UseModelName)
+		if err != nil {
+			pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error updating JSON: %s", err.Error()))
+			return
+		}
+	} else {
 		profileName, modelName := splitRequestedModel(requestedModel)
 		if profileName != "" {
 			bodyBytes, err = sjson.SetBytes(bodyBytes, "model", modelName)
@@ -366,17 +375,19 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
 			}
 		}
 
-		c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
-
-		// dechunk it as we already have all the body bytes see issue #11
-		c.Request.Header.Del("transfer-encoding")
-		c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
-
-		process.ProxyRequest(c.Writer, c.Request)
 	}
+
+	c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
+
+	// dechunk it as we already have all the body bytes see issue #11
+	c.Request.Header.Del("transfer-encoding")
+	c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
+
+	process.ProxyRequest(c.Writer, c.Request)
+
 }
 
-func (pm *ProxyManager) proxyOAIAudioTranscriptionHandler(c *gin.Context) {
+func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
 	// We need to reconstruct the multipart form in any case since the body is consumed
 	// Create a new buffer for the reconstructed request
 	var requestBuffer bytes.Buffer
@@ -410,8 +421,12 @@
 		for _, value := range values {
 			fieldValue := value
 			// If this is the model field and we have a profile, use just the model name
-			if key == "model" && profileName != "" {
-				fieldValue = modelName
+			if key == "model" {
+				if process.config.UseModelName != "" {
+					fieldValue = process.config.UseModelName
+				} else if profileName != "" {
+					fieldValue = modelName
+				}
 			}
 			field, err := multipartWriter.CreateFormField(key)
 			if err != nil {
diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go
index 470e8d7..a8c1576 100644
--- a/proxy/proxymanager_test.go
+++ b/proxy/proxymanager_test.go
@@ -532,3 +532,112 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
 		})
 	}
 }
+
+func TestProxyManager_SplitRequestedModel(t *testing.T) {
+
+	tests := []struct {
+		name            string
+		requestedModel  string
+		expectedProfile string
+		expectedModel   string
+	}{
+		{"no profile", "gpt-4", "", "gpt-4"},
+		{"with profile", "profile1:gpt-4", "profile1", "gpt-4"},
+		{"only profile", "profile1:", "profile1", ""},
+		{"empty model", ":gpt-4", "", "gpt-4"},
+		{"empty profile", ":", "", ""},
+		{"no split char", "gpt-4", "", "gpt-4"},
+		{"profile and model with delimiter", "profile1:delimiter:gpt-4", "profile1", "delimiter:gpt-4"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			profileName, modelName := splitRequestedModel(tt.requestedModel)
+			if profileName != tt.expectedProfile {
+				t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
+			}
+			if modelName != tt.expectedModel {
+				t.Errorf("splitRequestedModel(%q) = %q, %q; want %q, %q", tt.requestedModel, profileName, modelName, tt.expectedProfile, tt.expectedModel)
+			}
+		})
+	}
+}
+
+// Test that useModelName in the configuration overrides the model name sent to the upstream server
+func TestProxyManager_UseModelName(t *testing.T) {
+
+	upstreamModelName := "upstreamModel"
+
+	modelConfig := getTestSimpleResponderConfig(upstreamModelName)
+	modelConfig.UseModelName = upstreamModelName
+
+	config := &Config{
+		HealthCheckTimeout: 15,
+		Profiles: map[string][]string{
+			"test": {"model1"},
+		},
+
+		Models: map[string]ModelConfig{
+			"model1": modelConfig,
+		},
+	}
+
+	proxy := New(config)
+	defer proxy.StopProcesses()
+
+	tests := []struct {
+		description    string
+		requestedModel string
+	}{
+		{"useModelName overrides requested model", "model1"},
+		{"useModelName overrides requested profile:model", "test:model1"},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.description+": /v1/chat/completions", func(t *testing.T) {
+			reqBody := fmt.Sprintf(`{"model":"%s"}`, tt.requestedModel)
+			req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
+			w := httptest.NewRecorder()
+
+			proxy.HandlerFunc(w, req)
+			assert.Equal(t, http.StatusOK, w.Code)
+			assert.Contains(t, w.Body.String(), upstreamModelName)
+
+		})
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.description+": /v1/audio/transcriptions", func(t *testing.T) {
+			// Create a buffer with multipart form data
+			var b bytes.Buffer
+			w := multipart.NewWriter(&b)
+
+			// Add the model field
+			fw, err := w.CreateFormField("model")
+			assert.NoError(t, err)
+			_, err = fw.Write([]byte(tt.requestedModel))
+			assert.NoError(t, err)
+
+			// Add a file field
+			fw, err = w.CreateFormFile("file", "test.mp3")
+			assert.NoError(t, err)
+			_, err = fw.Write([]byte("test"))
+			assert.NoError(t, err)
+			w.Close()
+
+			// Create the request with the multipart form data
+			req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
+			req.Header.Set("Content-Type", w.FormDataContentType())
+			rec := httptest.NewRecorder()
+			proxy.HandlerFunc(rec, req)
+
+			// Verify the response
+			assert.Equal(t, http.StatusOK, rec.Code)
+			var response map[string]string
+			err = json.Unmarshal(rec.Body.Bytes(), &response)
+			assert.NoError(t, err)
+			assert.Equal(t, upstreamModelName, response["model"])
+		})
+	}
+
+}
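A minimal, self-contained sketch of the body rewrite the handler performs when `useModelName` is set. It uses the same `sjson.SetBytes` call as the patch; the `main` package wrapper, import path, and the `qwq` / `qwen:qwq` model names are illustrative assumptions, not part of the change.

```go
package main

import (
	"fmt"

	"github.com/tidwall/sjson"
)

func main() {
	// Incoming OpenAI-style request body: the client asked for the local alias "qwq".
	body := []byte(`{"model":"qwq","messages":[{"role":"user","content":"hi"}]}`)

	// With useModelName set (e.g. "qwen:qwq" as in the README example), the proxy
	// overwrites the "model" field before forwarding the request upstream.
	useModelName := "qwen:qwq"

	body, err := sjson.SetBytes(body, "model", useModelName)
	if err != nil {
		panic(err)
	}

	fmt.Println(string(body))
	// {"model":"qwen:qwq","messages":[{"role":"user","content":"hi"}]}
}
```

When `useModelName` is empty, the handler keeps the existing behaviour: the profile prefix is stripped via `splitRequestedModel` and only the bare model name is forwarded upstream.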