Fix model alias usage in upstream path (#230)

Model alias values are not properly resolved and work in upstream/ path.

Related to #229.
This commit is contained in:
Benson Wong
2025-08-07 20:16:56 -07:00
committed by GitHub
parent 5b10b3c23f
commit 10569ed546
2 changed files with 65 additions and 51 deletions

View File

@@ -361,7 +361,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
return return
} }
processGroup, _, err := pm.swapProcessGroup(requestedModel) processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
if err != nil { if err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error())) pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
return return
@@ -369,7 +369,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
// rewrite the path // rewrite the path
c.Request.URL.Path = c.Param("upstreamPath") c.Request.URL.Path = c.Param("upstreamPath")
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request) processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
} }
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {

View File

@@ -9,6 +9,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strconv" "strconv"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -280,48 +281,48 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
} }
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) { func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
// Intentionally add models in non-sorted order and with an unlisted model // Intentionally add models in non-sorted order and with an unlisted model
config := Config{ config := Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"zeta": getTestSimpleResponderConfig("zeta"), "zeta": getTestSimpleResponderConfig("zeta"),
"alpha": getTestSimpleResponderConfig("alpha"), "alpha": getTestSimpleResponderConfig("alpha"),
"beta": getTestSimpleResponderConfig("beta"), "beta": getTestSimpleResponderConfig("beta"),
"hidden": func() ModelConfig { "hidden": func() ModelConfig {
mc := getTestSimpleResponderConfig("hidden") mc := getTestSimpleResponderConfig("hidden")
mc.Unlisted = true mc.Unlisted = true
return mc return mc
}(), }(),
}, },
LogLevel: "error", LogLevel: "error",
} }
proxy := New(config) proxy := New(config)
// Request models list // Request models list
req := httptest.NewRequest("GET", "/v1/models", nil) req := httptest.NewRequest("GET", "/v1/models", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
proxy.ServeHTTP(w, req) proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
var response struct { var response struct {
Data []map[string]interface{} `json:"data"` Data []map[string]interface{} `json:"data"`
} }
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err) t.Fatalf("Failed to parse JSON response: %v", err)
} }
// We expect only the listed models in sorted order by id // We expect only the listed models in sorted order by id
expectedOrder := []string{"alpha", "beta", "zeta"} expectedOrder := []string{"alpha", "beta", "zeta"}
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") { if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
got := make([]string, 0, len(response.Data)) got := make([]string, 0, len(response.Data))
for _, m := range response.Data { for _, m := range response.Data {
id, _ := m["id"].(string) id, _ := m["id"].(string)
got = append(got, id) got = append(got, id)
} }
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending") assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
} }
} }
func TestProxyManager_Shutdown(t *testing.T) { func TestProxyManager_Shutdown(t *testing.T) {
@@ -656,21 +657,34 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
} }
func TestProxyManager_Upstream(t *testing.T) { func TestProxyManager_Upstream(t *testing.T) {
config := AddDefaultGroupToConfig(Config{ configStr := fmt.Sprintf(`
HealthCheckTimeout: 15, logLevel: error
Models: map[string]ModelConfig{ models:
"model1": getTestSimpleResponderConfig("model1"), model1:
}, cmd: %s -port ${PORT} -silent -respond model1
LogLevel: "error", aliases: [model-alias]
}) `, getSimpleResponderPath())
config, err := LoadConfigFromReader(strings.NewReader(configStr))
assert.NoError(t, err)
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest) defer proxy.StopProcesses(StopWaitForInflightRequest)
req := httptest.NewRequest("GET", "/upstream/model1/test", nil) t.Run("main model name", func(t *testing.T) {
rec := httptest.NewRecorder() req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
proxy.ServeHTTP(rec, req) rec := httptest.NewRecorder()
assert.Equal(t, http.StatusOK, rec.Code) proxy.ServeHTTP(rec, req)
assert.Equal(t, "model1", rec.Body.String()) assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
})
t.Run("model alias", func(t *testing.T) {
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
rec := httptest.NewRecorder()
proxy.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
})
} }
func TestProxyManager_ChatContentLength(t *testing.T) { func TestProxyManager_ChatContentLength(t *testing.T) {