proxy: add unload of single model (#318)

This adds a new API endpoint, /api/models/unload/*model, that unloads a single model. In the UI when a model is in a ReadyState it will have a new button to unload it. 

Fixes #312
This commit is contained in:
Benson Wong
2025-09-24 20:53:48 -07:00
committed by GitHub
parent fc3bb716df
commit 1a84926505
7 changed files with 138 additions and 14 deletions

View File

@@ -86,6 +86,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
return slices.Contains(pg.config.Groups[pg.id].Members, modelName) return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
} }
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
pg.Lock()
process, exists := pg.processes[modelID]
if !exists {
pg.Unlock()
return fmt.Errorf("process not found for %s", modelID)
}
if pg.lastUsedProcess == modelID {
pg.lastUsedProcess = ""
}
pg.Unlock()
switch strategy {
case StopImmediately:
process.StopImmediately()
default:
process.Stop()
}
return nil
}
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) { func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
pg.Lock() pg.Lock()
defer pg.Unlock() defer pg.Unlock()

View File

@@ -228,7 +228,6 @@ func (pm *ProxyManager) setupGinEngine() {
c.Redirect(http.StatusFound, "/ui/models") c.Redirect(http.StatusFound, "/ui/models")
}) })
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream) pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler) pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler) pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
pm.ginEngine.GET("/health", func(c *gin.Context) { pm.ginEngine.GET("/health", func(c *gin.Context) {

View File

@@ -3,8 +3,10 @@ package proxy
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"sort" "sort"
"strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event" "github.com/mostlygeek/llama-swap/event"
@@ -23,6 +25,7 @@ func addApiHandlers(pm *ProxyManager) {
apiGroup := pm.ginEngine.Group("/api") apiGroup := pm.ginEngine.Group("/api")
{ {
apiGroup.POST("/models/unload", pm.apiUnloadAllModels) apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
apiGroup.GET("/events", pm.apiSendEvents) apiGroup.GET("/events", pm.apiSendEvents)
apiGroup.GET("/metrics", pm.apiGetMetrics) apiGroup.GET("/metrics", pm.apiGetMetrics)
} }
@@ -202,3 +205,25 @@ func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
} }
c.Data(http.StatusOK, "application/json", jsonData) c.Data(http.StatusOK, "application/json", jsonData)
} }
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
return
}
processGroup := pm.findGroupByModelName(realModelName)
if processGroup == nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
return
}
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
return
} else {
c.String(http.StatusOK, "OK")
}
}

View File

@@ -401,11 +401,65 @@ func TestProxyManager_Unload(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, w.Body.String(), "OK") assert.Equal(t, w.Body.String(), "OK")
// give it a bit of time to stop select {
<-time.After(time.Millisecond * 250) case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
// good
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop")
}
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped) assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
} }
func TestProxyManager_UnloadSingleModel(t *testing.T) {
const testGroupId = "testGroup"
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
},
Groups: map[string]GroupConfig{
testGroupId: {
Swap: false,
Members: []string{"model1", "model2"},
},
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopImmediately)
// start both model
for _, modelName := range []string{"model1", "model2"} {
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
}
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
w := httptest.NewRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
if !assert.Equal(t, w.Body.String(), "OK") {
t.FailNow()
}
select {
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
// good
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for model1 to stop")
}
assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
}
// Test issue #61 `Listing the current list of models and the loaded model.` // Test issue #61 `Listing the current list of models and the loaded model.`
func TestProxyManager_RunningEndpoint(t *testing.T) { func TestProxyManager_RunningEndpoint(t *testing.T) {
// Shared configuration // Shared configuration

View File

@@ -16,6 +16,7 @@ interface APIProviderType {
models: Model[]; models: Model[];
listModels: () => Promise<Model[]>; listModels: () => Promise<Model[]>;
unloadAllModels: () => Promise<void>; unloadAllModels: () => Promise<void>;
unloadSingleModel: (model: string) => Promise<void>;
loadModel: (model: string) => Promise<void>; loadModel: (model: string) => Promise<void>;
enableAPIEvents: (enabled: boolean) => void; enableAPIEvents: (enabled: boolean) => void;
proxyLogs: string; proxyLogs: string;
@@ -177,7 +178,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
const unloadAllModels = useCallback(async () => { const unloadAllModels = useCallback(async () => {
try { try {
const response = await fetch(`/api/models/unload/`, { const response = await fetch(`/api/models/unload`, {
method: "POST", method: "POST",
}); });
if (!response.ok) { if (!response.ok) {
@@ -189,6 +190,20 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
} }
}, []); }, []);
const unloadSingleModel = useCallback(async (model: string) => {
try {
const response = await fetch(`/api/models/unload/${model}`, {
method: "POST",
});
if (!response.ok) {
throw new Error(`Failed to unload model: ${response.status}`);
}
} catch (error) {
console.error("Failed to unload model", model, error);
throw error;
}
}, []);
const loadModel = useCallback(async (model: string) => { const loadModel = useCallback(async (model: string) => {
try { try {
const response = await fetch(`/upstream/${model}/`, { const response = await fetch(`/upstream/${model}/`, {
@@ -208,6 +223,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
models, models,
listModels, listModels,
unloadAllModels, unloadAllModels,
unloadSingleModel,
loadModel, loadModel,
enableAPIEvents, enableAPIEvents,
proxyLogs, proxyLogs,

View File

@@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
import { usePersistentState } from "../hooks/usePersistentState"; import { usePersistentState } from "../hooks/usePersistentState";
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels"; import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
import { useTheme } from "../contexts/ThemeProvider"; import { useTheme } from "../contexts/ThemeProvider";
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri"; import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine } from "react-icons/ri";
export default function ModelsPage() { export default function ModelsPage() {
const { isNarrow } = useTheme(); const { isNarrow } = useTheme();
@@ -37,7 +37,7 @@ export default function ModelsPage() {
} }
function ModelsPanel() { function ModelsPanel() {
const { models, loadModel, unloadAllModels } = useAPI(); const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
const [isUnloading, setIsUnloading] = useState(false); const [isUnloading, setIsUnloading] = useState(false);
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true); const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
@@ -90,7 +90,7 @@ function ModelsPanel() {
onClick={handleUnloadAllModels} onClick={handleUnloadAllModels}
disabled={isUnloading} disabled={isUnloading}
> >
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"} <RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
</button> </button>
</div> </div>
</div> </div>
@@ -119,13 +119,19 @@ function ModelsPanel() {
)} )}
</td> </td>
<td className="w-12"> <td className="w-12">
<button {model.state === "stopped" ? (
className="btn btn--sm" <button className="btn btn--sm" onClick={() => loadModel(model.id)}>
disabled={model.state !== "stopped"} Load
onClick={() => loadModel(model.id)} </button>
> ) : (
Load <button
</button> className="btn btn--sm"
onClick={() => unloadSingleModel(model.id)}
disabled={model.state !== "ready"}
>
Unload
</button>
)}
</td> </td>
<td className="w-20"> <td className="w-20">
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span> <span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>

View File

@@ -15,6 +15,7 @@ export default defineConfig({
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development "/api": "http://localhost:8080", // Proxy API calls to Go backend during development
"/logs": "http://localhost:8080", "/logs": "http://localhost:8080",
"/upstream": "http://localhost:8080", "/upstream": "http://localhost:8080",
"/unload": "http://localhost:8080",
}, },
}, },
}); });