Fix model alias usage in upstream path (#230)
Model alias values were not properly resolved and therefore did not work in the upstream/ path; requests are now proxied using the resolved (real) model name. Related to #229.
This commit is contained in:
@@ -361,7 +361,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
processGroup, _, err := pm.swapProcessGroup(requestedModel)
|
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||||
return
|
return
|
||||||
@@ -369,7 +369,7 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
c.Request.URL.Path = c.Param("upstreamPath")
|
c.Request.URL.Path = c.Param("upstreamPath")
|
||||||
processGroup.ProxyRequest(requestedModel, c.Writer, c.Request)
|
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -280,48 +281,48 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
||||||
// Intentionally add models in non-sorted order and with an unlisted model
|
// Intentionally add models in non-sorted order and with an unlisted model
|
||||||
config := Config{
|
config := Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
Models: map[string]ModelConfig{
|
Models: map[string]ModelConfig{
|
||||||
"zeta": getTestSimpleResponderConfig("zeta"),
|
"zeta": getTestSimpleResponderConfig("zeta"),
|
||||||
"alpha": getTestSimpleResponderConfig("alpha"),
|
"alpha": getTestSimpleResponderConfig("alpha"),
|
||||||
"beta": getTestSimpleResponderConfig("beta"),
|
"beta": getTestSimpleResponderConfig("beta"),
|
||||||
"hidden": func() ModelConfig {
|
"hidden": func() ModelConfig {
|
||||||
mc := getTestSimpleResponderConfig("hidden")
|
mc := getTestSimpleResponderConfig("hidden")
|
||||||
mc.Unlisted = true
|
mc.Unlisted = true
|
||||||
return mc
|
return mc
|
||||||
}(),
|
}(),
|
||||||
},
|
},
|
||||||
LogLevel: "error",
|
LogLevel: "error",
|
||||||
}
|
}
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
|
|
||||||
// Request models list
|
// Request models list
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
var response struct {
|
var response struct {
|
||||||
Data []map[string]interface{} `json:"data"`
|
Data []map[string]interface{} `json:"data"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// We expect only the listed models in sorted order by id
|
// We expect only the listed models in sorted order by id
|
||||||
expectedOrder := []string{"alpha", "beta", "zeta"}
|
expectedOrder := []string{"alpha", "beta", "zeta"}
|
||||||
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
|
if assert.Len(t, response.Data, len(expectedOrder), "unexpected number of listed models") {
|
||||||
got := make([]string, 0, len(response.Data))
|
got := make([]string, 0, len(response.Data))
|
||||||
for _, m := range response.Data {
|
for _, m := range response.Data {
|
||||||
id, _ := m["id"].(string)
|
id, _ := m["id"].(string)
|
||||||
got = append(got, id)
|
got = append(got, id)
|
||||||
}
|
}
|
||||||
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
|
assert.Equal(t, expectedOrder, got, "models should be sorted by id ascending")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Shutdown(t *testing.T) {
|
func TestProxyManager_Shutdown(t *testing.T) {
|
||||||
@@ -656,21 +657,34 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_Upstream(t *testing.T) {
|
func TestProxyManager_Upstream(t *testing.T) {
|
||||||
config := AddDefaultGroupToConfig(Config{
|
configStr := fmt.Sprintf(`
|
||||||
HealthCheckTimeout: 15,
|
logLevel: error
|
||||||
Models: map[string]ModelConfig{
|
models:
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
model1:
|
||||||
},
|
cmd: %s -port ${PORT} -silent -respond model1
|
||||||
LogLevel: "error",
|
aliases: [model-alias]
|
||||||
})
|
`, getSimpleResponderPath())
|
||||||
|
|
||||||
|
config, err := LoadConfigFromReader(strings.NewReader(configStr))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
t.Run("main model name", func(t *testing.T) {
|
||||||
rec := httptest.NewRecorder()
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
proxy.ServeHTTP(rec, req)
|
rec := httptest.NewRecorder()
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("model alias", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(rec, req)
|
||||||
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_ChatContentLength(t *testing.T) {
|
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user