diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index d81c716..602fe41 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -104,6 +104,8 @@ func New(config *Config) *ProxyManager { pm.ginEngine.GET("/upstream", pm.upstreamIndex) pm.ginEngine.Any("/upstream/:model_id/*upstreamPath", pm.proxyToUpstream) + pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) + pm.ginEngine.GET("/", func(c *gin.Context) { // Set the Content-Type header to text/html c.Header("Content-Type", "text/html") @@ -377,6 +379,11 @@ func (pm *ProxyManager) sendErrorResponse(c *gin.Context, statusCode int, messag } } +func (pm *ProxyManager) unloadAllModelsHandler(c *gin.Context) { + pm.StopProcesses() + c.String(http.StatusOK, "OK") +} + func ProcessKeyName(groupName, modelName string) string { return groupName + PROFILE_SPLIT_CHAR + modelName } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index f15e5f7..f167443 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -304,3 +304,25 @@ func TestProxyManager_Shutdown(t *testing.T) { }() wg.Wait() } + +func TestProxyManager_Unload(t *testing.T) { + config := &Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + }, + } + + proxy := New(config) + proc, err := proxy.swapModel("model1") + assert.NoError(t, err) + assert.NotNil(t, proc) + + assert.Len(t, proxy.currentProcesses, 1) + req := httptest.NewRequest("GET", "/unload", nil) + w := httptest.NewRecorder() + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, w.Body.String(), "OK") + assert.Len(t, proxy.currentProcesses, 0) +}