From abdc2bfdb370f0575945247c0fa92d4215a7eef9 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Thu, 16 Jan 2025 12:06:38 -0800 Subject: [PATCH] Fix panic when requesting non-members of profiles A panic occurs when a request for an invalid profile:model pair is made. The edge case is that the profile exists and the model exists but they're not configured as a pair. This adds an additional check to make sure the profile:model pair is valid before attempting to swap the model. --- proxy/proxymanager.go | 15 +++++++++++++ proxy/proxymanager_test.go | 44 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+) diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 53a8ba1..684dd37 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -202,6 +202,21 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) { return nil, fmt.Errorf("could not find modelID for %s", requestedModel) } + // check if model is part of the profile + if profileName != "" { + found := false + for _, item := range pm.config.Profiles[profileName] { + if item == realModelName { + found = true + break + } + } + + if !found { + return nil, fmt.Errorf("model %s part of profile %s", realModelName, profileName) + } + } + // exit early when already running, otherwise stop everything and swap requestedProcessKey := ProcessKeyName(profileName, realModelName) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 1d96c14..e1476ef 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -210,3 +210,47 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { // Ensure all expected models were returned assert.Empty(t, expectedModels, "not all expected models were returned") } + +func TestProxyManager_ProfileNonMember(t *testing.T) { + + model1 := "path1/model1" + model2 := "path2/model2" + + profileMemberName := ProcessKeyName("test", model1) + profileNonMemberName := ProcessKeyName("test", model2) + + config := &Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + model1: getTestSimpleResponderConfig("model1"), + model2: getTestSimpleResponderConfig("model2"), + }, + Profiles: map[string][]string{ + "test": {model1}, + }, + } + + proxy := New(config) + defer proxy.StopProcesses() + + // actual member of profile + { + reqBody := fmt.Sprintf(`{"model":"%s"}`, profileMemberName) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Contains(t, w.Body.String(), "model1") + } + + // actual model, but non-member will 404 + { + reqBody := fmt.Sprintf(`{"model":"%s"}`, profileNonMemberName) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusNotFound, w.Code) + } +}