proxy: add unload of single model (#318)
This adds a new API endpoint, /api/models/unload/*model, that unloads a single model. In the UI, a model in the ready state now has a button to unload it. Fixes #312
This commit is contained in:
@@ -86,6 +86,29 @@ func (pg *ProcessGroup) HasMember(modelName string) bool {
|
|||||||
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
return slices.Contains(pg.config.Groups[pg.id].Members, modelName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pg *ProcessGroup) StopProcess(modelID string, strategy StopStrategy) error {
|
||||||
|
pg.Lock()
|
||||||
|
|
||||||
|
process, exists := pg.processes[modelID]
|
||||||
|
if !exists {
|
||||||
|
pg.Unlock()
|
||||||
|
return fmt.Errorf("process not found for %s", modelID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pg.lastUsedProcess == modelID {
|
||||||
|
pg.lastUsedProcess = ""
|
||||||
|
}
|
||||||
|
pg.Unlock()
|
||||||
|
|
||||||
|
switch strategy {
|
||||||
|
case StopImmediately:
|
||||||
|
process.StopImmediately()
|
||||||
|
default:
|
||||||
|
process.Stop()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
func (pg *ProcessGroup) StopProcesses(strategy StopStrategy) {
|
||||||
pg.Lock()
|
pg.Lock()
|
||||||
defer pg.Unlock()
|
defer pg.Unlock()
|
||||||
|
|||||||
@@ -228,7 +228,6 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
c.Redirect(http.StatusFound, "/ui/models")
|
c.Redirect(http.StatusFound, "/ui/models")
|
||||||
})
|
})
|
||||||
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
pm.ginEngine.Any("/upstream/*upstreamPath", pm.proxyToUpstream)
|
||||||
|
|
||||||
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
pm.ginEngine.GET("/unload", pm.unloadAllModelsHandler)
|
||||||
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
pm.ginEngine.GET("/running", pm.listRunningProcessesHandler)
|
||||||
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
pm.ginEngine.GET("/health", func(c *gin.Context) {
|
||||||
|
|||||||
@@ -3,8 +3,10 @@ package proxy
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
@@ -23,6 +25,7 @@ func addApiHandlers(pm *ProxyManager) {
|
|||||||
apiGroup := pm.ginEngine.Group("/api")
|
apiGroup := pm.ginEngine.Group("/api")
|
||||||
{
|
{
|
||||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||||
|
apiGroup.POST("/models/unload/*model", pm.apiUnloadSingleModelHandler)
|
||||||
apiGroup.GET("/events", pm.apiSendEvents)
|
apiGroup.GET("/events", pm.apiSendEvents)
|
||||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||||
}
|
}
|
||||||
@@ -202,3 +205,25 @@ func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
c.Data(http.StatusOK, "application/json", jsonData)
|
c.Data(http.StatusOK, "application/json", jsonData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (pm *ProxyManager) apiUnloadSingleModelHandler(c *gin.Context) {
|
||||||
|
requestedModel := strings.TrimPrefix(c.Param("model"), "/")
|
||||||
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||||
|
if !found {
|
||||||
|
pm.sendErrorResponse(c, http.StatusNotFound, "Model not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
processGroup := pm.findGroupByModelName(realModelName)
|
||||||
|
if processGroup == nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("process group not found for model %s", requestedModel))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := processGroup.StopProcess(realModelName, StopImmediately); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error stopping process: %s", err.Error()))
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
c.String(http.StatusOK, "OK")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -401,11 +401,65 @@ func TestProxyManager_Unload(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
|
|
||||||
// give it a bit of time to stop
|
select {
|
||||||
<-time.After(time.Millisecond * 250)
|
case <-proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].cmdWaitChan:
|
||||||
|
// good
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for model1 to stop")
|
||||||
|
}
|
||||||
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
assert.Equal(t, proxy.processGroups[DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateStopped)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
||||||
|
const testGroupId = "testGroup"
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
},
|
||||||
|
Groups: map[string]GroupConfig{
|
||||||
|
testGroupId: {
|
||||||
|
Swap: false,
|
||||||
|
Members: []string{"model1", "model2"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses(StopImmediately)
|
||||||
|
|
||||||
|
// start both model
|
||||||
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model1"].CurrentState())
|
||||||
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
proxy.ServeHTTP(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
if !assert.Equal(t, w.Body.String(), "OK") {
|
||||||
|
t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-proxy.processGroups[testGroupId].processes["model1"].cmdWaitChan:
|
||||||
|
// good
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for model1 to stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, proxy.processGroups[testGroupId].processes["model1"].CurrentState(), StateStopped)
|
||||||
|
assert.Equal(t, proxy.processGroups[testGroupId].processes["model2"].CurrentState(), StateReady)
|
||||||
|
}
|
||||||
|
|
||||||
// Test issue #61 `Listing the current list of models and the loaded model.`
|
// Test issue #61 `Listing the current list of models and the loaded model.`
|
||||||
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
func TestProxyManager_RunningEndpoint(t *testing.T) {
|
||||||
// Shared configuration
|
// Shared configuration
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ interface APIProviderType {
|
|||||||
models: Model[];
|
models: Model[];
|
||||||
listModels: () => Promise<Model[]>;
|
listModels: () => Promise<Model[]>;
|
||||||
unloadAllModels: () => Promise<void>;
|
unloadAllModels: () => Promise<void>;
|
||||||
|
unloadSingleModel: (model: string) => Promise<void>;
|
||||||
loadModel: (model: string) => Promise<void>;
|
loadModel: (model: string) => Promise<void>;
|
||||||
enableAPIEvents: (enabled: boolean) => void;
|
enableAPIEvents: (enabled: boolean) => void;
|
||||||
proxyLogs: string;
|
proxyLogs: string;
|
||||||
@@ -177,7 +178,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
|||||||
|
|
||||||
const unloadAllModels = useCallback(async () => {
|
const unloadAllModels = useCallback(async () => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`/api/models/unload/`, {
|
const response = await fetch(`/api/models/unload`, {
|
||||||
method: "POST",
|
method: "POST",
|
||||||
});
|
});
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
@@ -189,6 +190,20 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
|||||||
}
|
}
|
||||||
}, []);
|
}, []);
|
||||||
|
|
||||||
|
// Ask the backend to unload one model via POST /api/models/unload/<model>.
// Rejects (after logging) when the request fails or the response is not OK.
const unloadSingleModel = useCallback(async (model: string) => {
  try {
    // Encode each path segment so characters such as "#", "?" or "%" in a
    // model ID cannot change the request URL; "/" separators are kept as-is.
    const modelPath = model.split("/").map(encodeURIComponent).join("/");
    const response = await fetch(`/api/models/unload/${modelPath}`, {
      method: "POST",
    });
    if (!response.ok) {
      throw new Error(`Failed to unload model: ${response.status}`);
    }
  } catch (error) {
    console.error("Failed to unload model", model, error);
    throw error;
  }
}, []);
|
||||||
|
|
||||||
const loadModel = useCallback(async (model: string) => {
|
const loadModel = useCallback(async (model: string) => {
|
||||||
try {
|
try {
|
||||||
const response = await fetch(`/upstream/${model}/`, {
|
const response = await fetch(`/upstream/${model}/`, {
|
||||||
@@ -208,6 +223,7 @@ export function APIProvider({ children, autoStartAPIEvents = true }: APIProvider
|
|||||||
models,
|
models,
|
||||||
listModels,
|
listModels,
|
||||||
unloadAllModels,
|
unloadAllModels,
|
||||||
|
unloadSingleModel,
|
||||||
loadModel,
|
loadModel,
|
||||||
enableAPIEvents,
|
enableAPIEvents,
|
||||||
proxyLogs,
|
proxyLogs,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import { LogPanel } from "./LogViewer";
|
|||||||
import { usePersistentState } from "../hooks/usePersistentState";
|
import { usePersistentState } from "../hooks/usePersistentState";
|
||||||
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
import { Panel, PanelGroup, PanelResizeHandle } from "react-resizable-panels";
|
||||||
import { useTheme } from "../contexts/ThemeProvider";
|
import { useTheme } from "../contexts/ThemeProvider";
|
||||||
import { RiEyeFill, RiEyeOffFill, RiStopCircleLine, RiSwapBoxFill } from "react-icons/ri";
|
import { RiEyeFill, RiEyeOffFill, RiSwapBoxFill, RiEjectLine } from "react-icons/ri";
|
||||||
|
|
||||||
export default function ModelsPage() {
|
export default function ModelsPage() {
|
||||||
const { isNarrow } = useTheme();
|
const { isNarrow } = useTheme();
|
||||||
@@ -37,7 +37,7 @@ export default function ModelsPage() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function ModelsPanel() {
|
function ModelsPanel() {
|
||||||
const { models, loadModel, unloadAllModels } = useAPI();
|
const { models, loadModel, unloadAllModels, unloadSingleModel } = useAPI();
|
||||||
const [isUnloading, setIsUnloading] = useState(false);
|
const [isUnloading, setIsUnloading] = useState(false);
|
||||||
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
const [showUnlisted, setShowUnlisted] = usePersistentState("showUnlisted", true);
|
||||||
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
const [showIdorName, setShowIdorName] = usePersistentState<"id" | "name">("showIdorName", "id"); // true = show ID, false = show name
|
||||||
@@ -90,7 +90,7 @@ function ModelsPanel() {
|
|||||||
onClick={handleUnloadAllModels}
|
onClick={handleUnloadAllModels}
|
||||||
disabled={isUnloading}
|
disabled={isUnloading}
|
||||||
>
|
>
|
||||||
<RiStopCircleLine size="24" /> {isUnloading ? "Unloading..." : "Unload"}
|
<RiEjectLine size="24" /> {isUnloading ? "Unloading..." : "Unload All"}
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -119,13 +119,19 @@ function ModelsPanel() {
|
|||||||
)}
|
)}
|
||||||
</td>
|
</td>
|
||||||
<td className="w-12">
|
<td className="w-12">
|
||||||
<button
|
{model.state === "stopped" ? (
|
||||||
className="btn btn--sm"
|
<button className="btn btn--sm" onClick={() => loadModel(model.id)}>
|
||||||
disabled={model.state !== "stopped"}
|
|
||||||
onClick={() => loadModel(model.id)}
|
|
||||||
>
|
|
||||||
Load
|
Load
|
||||||
</button>
|
</button>
|
||||||
|
) : (
|
||||||
|
<button
|
||||||
|
className="btn btn--sm"
|
||||||
|
onClick={() => unloadSingleModel(model.id)}
|
||||||
|
disabled={model.state !== "ready"}
|
||||||
|
>
|
||||||
|
Unload
|
||||||
|
</button>
|
||||||
|
)}
|
||||||
</td>
|
</td>
|
||||||
<td className="w-20">
|
<td className="w-20">
|
||||||
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
|
<span className={`w-16 text-center status status--${model.state}`}>{model.state}</span>
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ export default defineConfig({
|
|||||||
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
"/api": "http://localhost:8080", // Proxy API calls to Go backend during development
|
||||||
"/logs": "http://localhost:8080",
|
"/logs": "http://localhost:8080",
|
||||||
"/upstream": "http://localhost:8080",
|
"/upstream": "http://localhost:8080",
|
||||||
|
"/unload": "http://localhost:8080",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
Reference in New Issue
Block a user