diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index bdc31c7..3f123f6 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -361,7 +361,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { return } - processGroup, _, err := pm.swapProcessGroup(requestedModel) + processGroup, realModelName, err := pm.swapProcessGroup(requestedModel) if err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) return @@ -369,7 +369,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { // rewrite the path c.Request.URL.Path = c.Param("upstreamPath") - processGroup.ProxyRequest(requestedModel, c.Writer, c.Request) + processGroup.ProxyRequest(realModelName, c.Writer, c.Request) } func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 2447ad9..0959eea 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "strconv" + "strings" "sync" "testing" "time" @@ -280,48 +281,48 @@ func TestProxyManager_ListModelsHandler(t *testing.T) { } func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { - // Intentionally add models in non-sorted order and with an unlisted model - config := Config{ - HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ - "zeta": getTestSimpleResponderConfig("zeta"), - "alpha": getTestSimpleResponderConfig("alpha"), - "beta": getTestSimpleResponderConfig("beta"), - "hidden": func() ModelConfig { - mc := getTestSimpleResponderConfig("hidden") - mc.Unlisted = true - return mc - }(), - }, - LogLevel: "error", - } + // Intentionally add models in non-sorted order and with an unlisted model + config := Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "zeta": getTestSimpleResponderConfig("zeta"), + "alpha": getTestSimpleResponderConfig("alpha"), + "beta": getTestSimpleResponderConfig("beta"), + "hidden": func() ModelConfig { + mc := getTestSimpleResponderConfig("hidden") + mc.Unlisted = true + return mc + }(), + }, + LogLevel: "error", + } - proxy := New(config) + proxy := New(config) - // Request models list - req := httptest.NewRequest("GET", "/v1/models", nil) - w := httptest.NewRecorder() - proxy.ServeHTTP(w, req) + // Request models list + req := httptest.NewRequest("GET", "/v1/models", nil) + w := httptest.NewRecorder() + proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, http.StatusOK, w.Code) - var response struct { - Data []map[string]interface{} `json:"data"` - } - if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { - t.Fatalf("Failed to parse JSON response: %v", err) - } + var response struct { + Data []map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse JSON response: %v", err) + } - // We expect only the listed models in sorted order by id - expectedOrder := []string{"alpha", "beta", "zeta"} - if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") { - got := make([]string, 0, len(response.Data)) - for _, m := range response.Data { - id, _ := m["id"].(string) - got = append(got, id) - } - assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending") - } + // We expect only the listed models in sorted order by id + expectedOrder := []string{"alpha", "beta", "zeta"} + if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") { + got := make([]string, 0, len(response.Data)) + for _, m := range response.Data { + id, _ := m["id"].(string) + got = append(got, id) + } + assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending") + } } func TestProxyManager_Shutdown(t *testing.T) { @@ -656,21 +657,34 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) { } func TestProxyManager_Upstream(t *testing.T) { - config := AddDefaultGroupToConfig(Config{ - HealthCheckTimeout: 15, - Models: map[string]ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) + configStr := fmt.Sprintf(` +logLevel: error +models: + model1: + cmd: %s -port ${PORT} -silent -respond model1 + aliases: [model-alias] +`, getSimpleResponderPath()) + + config, err := LoadConfigFromReader(strings.NewReader(configStr)) + assert.NoError(t, err) proxy := New(config) defer proxy.StopProcesses(StopWaitForInflightRequest) - req := httptest.NewRequest("GET", "/upstream/model1/test", nil) - rec := httptest.NewRecorder() - proxy.ServeHTTP(rec, req) - assert.Equal(t, http.StatusOK, rec.Code) - assert.Equal(t, "model1", rec.Body.String()) + t.Run("main model name", func(t *testing.T) { + req := httptest.NewRequest("GET", "/upstream/model1/test", nil) + rec := httptest.NewRecorder() + proxy.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "model1", rec.Body.String()) + }) + + t.Run("model alias", func(t *testing.T) { + req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil) + rec := httptest.NewRecorder() + proxy.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code) + assert.Equal(t, "model1", rec.Body.String()) + }) } func TestProxyManager_ChatContentLength(t *testing.T) {