From 1a8492650508d97ff6f179b55aa0b123790913e4 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Wed, 24 Sep 2025 20:53:48 -0700 Subject: [PATCH] proxy: add unload of single model (#318) This adds a new API endpoint, /api/models/unload/*model, that unloads a single model. In the UI when a model is in a ReadyState it will have a new button to unload it. Fixes #312 --- proxy/processgroup.go | 23 +++++++++++++ proxy/proxymanager.go | 1 - proxy/proxymanager_api.go | 25 ++++++++++++++ proxy/proxymanager_test.go | 58 +++++++++++++++++++++++++++++++-- ui/src/contexts/APIProvider.tsx | 18 +++++++++- ui/src/pages/Models.tsx | 26 +++++++++------ ui/vite.config.ts | 1 + 7 files changed, 138 insertions(+), 14 deletions(-) diff --git a/proxy/processgroup.go b/proxy/processgroup.go index cca48c1..e4ce63e 100644 --- a/proxy/processgroup.go +++ b/proxy/processgroup.go @@ -86,6 +86,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool { return slices.Contains(pg.config.Groups[pg.id].Members, modelName) } +func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error { + pg.Lock() + + process, exists := pg.processes[modelID] + if !exists { + pg.Unlock() + return fmt.Errorf("process not found for %s", modelID) + } + + if pg.lastUsedProcess == modelID { + pg.lastUsedProcess = "" + } + pg.Unlock() + + switch strategy { + case StopImmediately: + process.StopImmediately() + default: + process.Stop() + } + return nil +} + func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) { pg.Lock() defer pg.Unlock() diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 9383f06..e485fff 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -228,7 +228,6 @@ func (pm *ProxyManager) setupGinEngine() { c.Redirect(http.StatusFound, "/ui/models") }) pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream) - pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) pm.ginEngine.GET("/running", pm.listRunningProcessesHandler) 
pm.ginEngine.GET("/health", func(c *gin.Context) { diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index 19460ea..7100e2c 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -3,8 +3,10 @@ package proxy import ( "context" "encoding/json" + "fmt" "net/http" "sort" + "strings" "github.com/gin-gonic/gin" "github.com/mostlygeek/llama-swap/event" @@ -23,6 +25,7 @@ func addApiHandlers(pm *ProxyManager) { apiGroup := pm.ginEngine.Group("/api") { apiGroup.POST("/models/unload", pm.apiUnloadAllModels) + apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler) apiGroup.GET("/events", pm.apiSendEvents) apiGroup.GET("/metrics", pm.apiGetMetrics) } @@ -202,3 +205,25 @@ func (pm *ProxyManager) apiGetMetrics(c *gin.Context) { } c.Data(http.StatusOK, "application/json", jsonData) } + +func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) { + requestedModel := strings.TrimPrefix(c.Param("model"), "/") + realModelName, found := pm.config.RealModelName(requestedModel) + if !found { + pm.sendErrorResponse(c, http.StatusNotFound, "Model not found") + return + } + + processGroup := pm.findGroupByModelName(realModelName) + if processGroup == nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel)) + return + } + + if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error())) + return + } else { + c.String(http.StatusOK, "OK") + } +} diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index e5cea5e..446a095 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -401,11 +401,65 @@ func TestProxyManager_Unload(t *testing.T) { assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, w.Body.String(), "OK") - // give it a bit of time to stop - <-time.After(time.Millisecond * 250) + 
select { + case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan: + // good + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for model1 to stop") + } assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) } +func TestProxyManager_UnloadSingleModel(t *testing.T) { + const testGroupId = "testGroup" + config := AddDefaultGroupToConfig(Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + }, + Groups: map[string]GroupConfig{ + testGroupId: { + Swap: false, + Members: []string{"model1", "model2"}, + }, + }, + LogLevel: "error", + }) + + proxy := New(config) + defer proxy.StopProcesses(StopImmediately) + + // start both model + for _, modelName := range []string{"model1", "model2"} { + reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName) + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + proxy.ServeHTTP(w, req) + } + + assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState()) + assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState()) + + req := httptest.NewRequest("POST", "/api/models/unload/model1", nil) + w := httptest.NewRecorder() + proxy.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + if !assert.Equal(t, w.Body.String(), "OK") { + t.FailNow() + } + + select { + case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan: + // good + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for model1 to stop") + } + + assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped) + assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady) +} + // Test issue #61 `Listing the current list of models and the 
// Unloads a single model by POSTing to /api/models/unload/<model>.
// Rejects (after logging to the console) when the request fails or the
// response status is not OK, so callers can react to the failure.
const unloadSingleModel = useCallback(async (model: string) => {
  // Escape every path segment so characters such as "?", "#", "%" or spaces in
  // a model ID cannot change the request URL. "/" separators are kept because
  // the backend route is a wildcard (/api/models/unload/*model).
  const modelPath = model.split("/").map(encodeURIComponent).join("/");
  try {
    const response = await fetch(`/api/models/unload/${modelPath}`, {
      method: "POST",
    });
    if (!response.ok) {
      throw new Error(`Failed to unload model: ${response.status}`);
    }
  } catch (error) {
    console.error("Failed to unload model", model, error);
    throw error;
  }
}, []);
"react-resizable-panels"; import { useTheme } from "../contexts/ThemeProvider"; -import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri"; +import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine } from "react-icons/ri"; export default function ModelsPage() { const { isNarrow } = useTheme(); @@ -37,7 +37,7 @@ export default function ModelsPage() { } function ModelsPanel() { - const { models, loadModel, unloadAllModels } = useAPI(); + const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI(); const [isUnloading, setIsUnloading] = useState(false); const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true); const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name @@ -90,7 +90,7 @@ function ModelsPanel() { onClick={handleUnloadAllModels} disabled={isUnloading} > - {isUnloading ? "Unloading..." : "Unload"} + {isUnloading ? "Unloading..." : "Unload All"} @@ -119,13 +119,19 @@ function ModelsPanel() { )} - + {model.state === "stopped" ? ( + + ) : ( + + )} {model.state} diff --git a/ui/vite.config.ts b/ui/vite.config.ts index 1a18997..dd0e4f3 100644 --- a/ui/vite.config.ts +++ b/ui/vite.config.ts @@ -15,6 +15,7 @@ export default defineConfig({ "/api": "http://localhost:8080", // Proxy API calls to Go backend during development "/logs": "http://localhost:8080", "/upstream": "http://localhost:8080", + "/unload": "http://localhost:8080", }, }, });