proxy: add unload of single model (#318)
This adds a new API endpoint, /api/models/unload/*model, that unloads a single model. In the UI when a model is in a ReadyState it will have a new button to unload it. Fixes #312
This commit is contained in:
@@ -86,6 +86,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||
pg.Lock()
|
||||
|
||||
process, exists := pg.processes[modelID]
|
||||
if !exists {
|
||||
pg.Unlock()
|
||||
return fmt.Errorf("process not found for %s", modelID)
|
||||
}
|
||||
|
||||
if pg.lastUsedProcess == modelID {
|
||||
pg.lastUsedProcess = ""
|
||||
}
|
||||
pg.Unlock()
|
||||
|
||||
switch strategy {
|
||||
case StopImmediately:
|
||||
process.StopImmediately()
|
||||
default:
|
||||
process.Stop()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||
pg.Lock()
|
||||
defer pg.Unlock()
|
||||
|
||||
@@ -228,7 +228,6 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
c.Redirect(http.StatusFound, "/ui/models")
|
||||
})
|
||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||
|
||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||
|
||||
@@ -3,8 +3,10 @@ package proxy
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
@@ -23,6 +25,7 @@ func addApiHandlers(pm *ProxyManager) {
|
||||
apiGroup := pm.ginEngine.Group("/api")
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
}
|
||||
@@ -202,3 +205,25 @@ func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup := pm.findGroupByModelName(realModelName)
|
||||
if processGroup == nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||
return
|
||||
}
|
||||
|
||||
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
||||
return
|
||||
} else {
|
||||
c.String(http.StatusOK, "OK")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -401,11 +401,65 @@ func TestProxyManager_Unload(t *testing.T) {
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.Equal(t, w.Body.String(), "OK")
|
||||
|
||||
// give it a bit of time to stop
|
||||
<-time.After(time.Millisecond * 250)
|
||||
select {
|
||||
case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
||||
// good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for model1 to stop")
|
||||
}
|
||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||
}
|
||||
|
||||
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
||||
const testGroupId = "testGroup"
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
},
|
||||
Groups: map[string]GroupConfig{
|
||||
testGroupId: {
|
||||
Swap: false,
|
||||
Members: []string{"model1", "model2"},
|
||||
},
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopImmediately)
|
||||
|
||||
// start both model
|
||||
for _, modelName := range []string{"model1", "model2"} {
|
||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
}
|
||||
|
||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
|
||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
||||
w := httptest.NewRecorder()
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
if !assert.Equal(t, w.Body.String(), "OK") {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
|
||||
// good
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timeout waiting for model1 to stop")
|
||||
}
|
||||
|
||||
assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
|
||||
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
|
||||
}
|
||||
|
||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||
// Shared configuration
|
||||
|
||||
Reference in New Issue
Block a user