change profile split character to : (colon) (#21)

- change the separator between the profile name and the model name from `/` to `:` for multiple models loaded as part of a profile
- this is a breaking change, but it improves compatibility with other inference engines whose model names themselves contain `/`, allowing model references like `coding:Qwen/Qwen-2.5-Coder-32B`
This commit is contained in:
Benson Wong
2024-12-01 09:10:50 -08:00
committed by GitHub
parent 9fc5d5b5eb
commit 04b4760e7e
2 changed files with 40 additions and 21 deletions

View File

@@ -14,6 +14,10 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
// PROFILE_SPLIT_CHAR is the separator between a profile name and a model
// name in a requested model string ("profile:model"). swapModel splits the
// request on its first occurrence, and ProcessKeyName uses it to build the
// keys of currentProcesses.
const PROFILE_SPLIT_CHAR = ":"
type ProxyManager struct { type ProxyManager struct {
sync.Mutex sync.Mutex
@@ -106,15 +110,15 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
defer pm.Unlock() defer pm.Unlock()
// Check if requestedModel contains a / // Check if requestedModel contains a /
groupName, modelName := "", requestedModel profileName, modelName := "", requestedModel
if idx := strings.Index(requestedModel, "/"); idx != -1 { if idx := strings.Index(requestedModel, PROFILE_SPLIT_CHAR); idx != -1 {
groupName = requestedModel[:idx] profileName = requestedModel[:idx]
modelName = requestedModel[idx+1:] modelName = requestedModel[idx+1:]
} }
if groupName != "" { if profileName != "" {
if _, found := pm.config.Profiles[groupName]; !found { if _, found := pm.config.Profiles[profileName]; !found {
return nil, fmt.Errorf("model group not found %s", groupName) return nil, fmt.Errorf("model group not found %s", profileName)
} }
} }
@@ -125,7 +129,8 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
} }
// exit early when already running, otherwise stop everything and swap // exit early when already running, otherwise stop everything and swap
requestedProcessKey := groupName + "/" + realModelName requestedProcessKey := ProcessKeyName(profileName, realModelName)
if process, found := pm.currentProcesses[requestedProcessKey]; found { if process, found := pm.currentProcesses[requestedProcessKey]; found {
return process, nil return process, nil
} }
@@ -133,25 +138,25 @@ func (pm *ProxyManager) swapModel(requestedModel string) (*Process, error) {
// stop all running models // stop all running models
pm.stopProcesses() pm.stopProcesses()
if groupName == "" { if profileName == "" {
modelConfig, modelID, found := pm.config.FindConfig(realModelName) modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found { if !found {
return nil, fmt.Errorf("could not find configuration for %s", realModelName) return nil, fmt.Errorf("could not find configuration for %s", realModelName)
} }
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := groupName + "/" + modelID processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process pm.currentProcesses[processKey] = process
} else { } else {
for _, modelName := range pm.config.Profiles[groupName] { for _, modelName := range pm.config.Profiles[profileName] {
if realModelName, found := pm.config.RealModelName(modelName); found { if realModelName, found := pm.config.RealModelName(modelName); found {
modelConfig, modelID, found := pm.config.FindConfig(realModelName) modelConfig, modelID, found := pm.config.FindConfig(realModelName)
if !found { if !found {
return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, groupName) return nil, fmt.Errorf("could not find configuration for %s in group %s", realModelName, profileName)
} }
process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor) process := NewProcess(modelID, pm.config.HealthCheckTimeout, modelConfig, pm.logMonitor)
processKey := groupName + "/" + modelID processKey := ProcessKeyName(profileName, modelID)
pm.currentProcesses[processKey] = process pm.currentProcesses[processKey] = process
} }
} }
@@ -201,3 +206,7 @@ func (pm *ProxyManager) proxyNoRouteHandler(c *gin.Context) {
c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request")) c.AbortWithError(http.StatusBadRequest, fmt.Errorf("no strategy to handle request"))
} }
// ProcessKeyName returns the key under which a process is stored in
// ProxyManager.currentProcesses: profileName and modelName joined by
// PROFILE_SPLIT_CHAR.
//
// profileName may be empty when a model is requested without a profile; the
// resulting key then begins with the separator (for example ":model1").
func ProcessKeyName(profileName, modelName string) string {
	return profileName + PROFILE_SPLIT_CHAR + modelName
}

View File

@@ -33,7 +33,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Contains(t, w.Body.String(), modelName) assert.Contains(t, w.Body.String(), modelName)
_, exists := proxy.currentProcesses["/"+modelName] _, exists := proxy.currentProcesses[ProcessKeyName("", modelName)]
assert.True(t, exists, "expected %s key in currentProcesses", modelName) assert.True(t, exists, "expected %s key in currentProcesses", modelName)
} }
@@ -43,21 +43,31 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
} }
func TestProxyManager_SwapMultiProcess(t *testing.T) { func TestProxyManager_SwapMultiProcess(t *testing.T) {
model1 := "path1/model1"
model2 := "path2/model2"
profileModel1 := ProcessKeyName("test", model1)
profileModel2 := ProcessKeyName("test", model2)
config := &Config{ config := &Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,
Models: map[string]ModelConfig{ Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"), model1: getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"), model2: getTestSimpleResponderConfig("model2"),
}, },
Profiles: map[string][]string{ Profiles: map[string][]string{
"test": {"model1", "model2"}, "test": {model1, model2},
}, },
} }
proxy := New(config) proxy := New(config)
defer proxy.StopProcesses() defer proxy.StopProcesses()
for modelID, requestedModel := range map[string]string{"model1": "test/model1", "model2": "test/model2"} { for modelID, requestedModel := range map[string]string{
"model1": profileModel1,
"model2": profileModel2,
} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel) reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder() w := httptest.NewRecorder()
@@ -69,11 +79,11 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
// make sure there's two loaded models // make sure there's two loaded models
assert.Len(t, proxy.currentProcesses, 2) assert.Len(t, proxy.currentProcesses, 2)
_, exists := proxy.currentProcesses["test/model1"] _, exists := proxy.currentProcesses[profileModel1]
assert.True(t, exists, "expected test/model1 key in currentProcesses") assert.True(t, exists, "expected "+profileModel1+" key in currentProcesses")
_, exists = proxy.currentProcesses["test/model2"] _, exists = proxy.currentProcesses[profileModel2]
assert.True(t, exists, "expected test/model2 key in currentProcesses") assert.True(t, exists, "expected "+profileModel2+" key in currentProcesses")
} }
// When a request for a different model comes in ProxyManager should wait until // When a request for a different model comes in ProxyManager should wait until