diff --git a/go.mod b/go.mod index d7b3ac1..f4e708c 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 gopkg.in/yaml.v3 v3.0.1 + github.com/kelindar/event v1.5.2 ) require ( diff --git a/go.sum b/go.sum index 6a867c0..92518f5 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaU github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kelindar/event v1.5.2 h1:qtgssZqMh/QQMCIxlbx4wU3DoMHOrJXKdiZhphJ4YbY= +github.com/kelindar/event v1.5.2/go.mod h1:UxWPQjWK8u0o9Z3ponm2mgREimM95hm26/M9z8F488Q= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= diff --git a/llama-swap.go b/llama-swap.go index f801ad7..15e0299 100644 --- a/llama-swap.go +++ b/llama-swap.go @@ -14,6 +14,7 @@ import ( "github.com/fsnotify/fsnotify" "github.com/gin-gonic/gin" + "github.com/kelindar/event" "github.com/mostlygeek/llama-swap/proxy" ) @@ -53,144 +54,129 @@ func main() { gin.SetMode(gin.ReleaseMode) } - proxyManager := proxy.New(config) - // Setup channels for server management - reloadChan := make(chan *proxy.ProxyManager) exitChan := make(chan struct{}) sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) // Create server with initial handler srv := &http.Server{ - Addr: *listenStr, - Handler: proxyManager, + Addr: *listenStr, } + // Support for watching config and reloading when it changes + reloadProxyManager := func() { + if currentPM, ok := srv.Handler.(*proxy.ProxyManager); ok { + config, err = proxy.LoadConfig(*configPath) + if err != nil { + fmt.Printf("Warning, unable to reload configuration: %v\n", err) + return + } + + fmt.Println("Configuration Changed") + currentPM.Shutdown() + srv.Handler = proxy.New(config) + fmt.Println("Configuration Reloaded") + + // wait a few seconds and tell any UI to reload + time.AfterFunc(3*time.Second, func() { + event.Emit(proxy.ConfigFileChangedEvent{ + ReloadingState: proxy.ReloadingStateEnd, + }) + }) + } else { + config, err = proxy.LoadConfig(*configPath) + if err != nil { + fmt.Printf("Error, unable to load configuration: %v\n", err) + os.Exit(1) + } + srv.Handler = proxy.New(config) + } + } + + // load the initial proxy manager + reloadProxyManager() + debouncedReload := debounce(time.Second, reloadProxyManager) + if *watchConfig { + defer event.On(func(e proxy.ConfigFileChangedEvent) { + if e.ReloadingState == proxy.ReloadingStateStart { + debouncedReload() + } + })() + + fmt.Println("Watching Configuration for changes") + go func() { + absConfigPath, err := filepath.Abs(*configPath) + if err != nil { + fmt.Printf("Error getting absolute path for watching config file: %v\n", err) + return + } + watcher, err := fsnotify.NewWatcher() + if err != nil { + fmt.Printf("Error creating file watcher: %v. File watching disabled.\n", err) + return + } + + err = watcher.Add(absConfigPath) + if err != nil { + fmt.Printf("Error adding config path (%s) to watcher: %v. 
File watching disabled.", absConfigPath, err) + return + } + + defer watcher.Close() + for { + select { + case changeEvent := <-watcher.Events: + if changeEvent.Name == absConfigPath && (changeEvent.Has(fsnotify.Write) || changeEvent.Has(fsnotify.Create) || changeEvent.Has(fsnotify.Remove)) { + event.Emit(proxy.ConfigFileChangedEvent{ + ReloadingState: proxy.ReloadingStateStart, + }) + } + + case err := <-watcher.Errors: + log.Printf("File watcher error: %v", err) + } + } + }() + } + + // shutdown on signal + go func() { + sig := <-sigChan + fmt.Printf("Received signal %v, shutting down...\n", sig) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + if pm, ok := srv.Handler.(*proxy.ProxyManager); ok { + pm.Shutdown() + } else { + fmt.Println("srv.Handler is not of type *proxy.ProxyManager") + } + + if err := srv.Shutdown(ctx); err != nil { + fmt.Printf("Server shutdown error: %v\n", err) + } + close(exitChan) + }() + // Start server fmt.Printf("llama-swap listening on %s\n", *listenStr) go func() { if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - fmt.Printf("Fatal server error: %v\n", err) - close(exitChan) + log.Fatalf("Fatal server error: %v\n", err) } }() - // Handle config reloads and signals - go func() { - currentManager := proxyManager - for { - select { - case newManager := <-reloadChan: - log.Println("Config change detected, waiting for in-flight requests to complete...") - // Stop old manager processes gracefully (this waits for in-flight requests) - currentManager.StopProcesses(proxy.StopWaitForInflightRequest) - // Now do a full shutdown to clear the process map - currentManager.Shutdown() - currentManager = newManager - srv.Handler = newManager - log.Println("Server handler updated with new config") - case sig := <-sigChan: - fmt.Printf("Received signal %v, shutting down...\n", sig) - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - currentManager.Shutdown() - if err := srv.Shutdown(ctx); err != nil { - fmt.Printf("Server shutdown error: %v\n", err) - } - close(exitChan) - return - } - } - }() - - // Start file watcher if requested - if *watchConfig { - absConfigPath, err := filepath.Abs(*configPath) - if err != nil { - log.Printf("Error getting absolute path for config: %v. File watching disabled.", err) - } else { - go watchConfigFileWithReload(absConfigPath, reloadChan) - } - } - // Wait for exit signal <-exitChan } -// watchConfigFileWithReload monitors the configuration file and sends new ProxyManager instances through reloadChan. -func watchConfigFileWithReload(configPath string, reloadChan chan<- *proxy.ProxyManager) { - watcher, err := fsnotify.NewWatcher() - if err != nil { - log.Printf("Error creating file watcher: %v. File watching disabled.", err) - return - } - defer watcher.Close() - - err = watcher.Add(configPath) - if err != nil { - log.Printf("Error adding config path (%s) to watcher: %v. 
File watching disabled.", configPath, err) - return - } - - log.Printf("Watching config file for changes: %s", configPath) - - var debounceTimer *time.Timer - debounceDuration := 2 * time.Second - - for { - select { - case event, ok := <-watcher.Events: - if !ok { - return - } - // We only care about writes/creates to the specific config file - if event.Name == configPath && (event.Has(fsnotify.Write) || event.Has(fsnotify.Create) || event.Has(fsnotify.Remove)) { - // Reset or start the debounce timer - if debounceTimer != nil { - debounceTimer.Stop() - } - debounceTimer = time.AfterFunc(debounceDuration, func() { - log.Printf("Config file modified: %s, reloading...", event.Name) - - // Try up to 3 times with exponential backoff - var newConfig proxy.Config - var err error - for retries := 0; retries < 3; retries++ { - // Load new configuration - newConfig, err = proxy.LoadConfig(configPath) - if err == nil { - break - } - log.Printf("Error loading new config (attempt %d/3): %v", retries+1, err) - if retries < 2 { - time.Sleep(time.Duration(1< swapState() State transitioned from %s to %s", p.ID, expectedState, newState) + event.Emit(ProcessStateChangeEvent{ProcessName: p.ID, NewState: newState, OldState: expectedState}) return p.state, nil } diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index bbb8101..8a50a70 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -2,6 +2,7 @@ package proxy import ( "bytes" + "context" "fmt" "io" "mime/multipart" @@ -33,6 +34,10 @@ type ProxyManager struct { muxLogger *LogMonitor processGroups map[string]*ProcessGroup + + // shutdown signaling + shutdownCtx context.Context + shutdownCancel context.CancelFunc } func New(config Config) *ProxyManager { @@ -63,6 +68,8 @@ func New(config Config) *ProxyManager { upstreamLogger.SetLogLevel(LevelInfo) } + shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + pm := &ProxyManager{ config: config, ginEngine: gin.New(), @@ -72,6 +79,9 @@ func New(config Config) *ProxyManager { upstreamLogger: upstreamLogger, processGroups: make(map[string]*ProcessGroup), + + shutdownCtx: shutdownCtx, + shutdownCancel: shutdownCancel, } // create the process groups @@ -157,9 +167,7 @@ func (pm *ProxyManager) setupGinEngine() { // in proxymanager_loghandlers.go pm.ginEngine.GET("/logs", pm.sendLogsHandlers) pm.ginEngine.GET("/logs/stream", pm.streamLogsHandler) - pm.ginEngine.GET("/logs/streamSSE", pm.streamLogsHandlerSSE) pm.ginEngine.GET("/logs/stream/:logMonitorID", pm.streamLogsHandler) - pm.ginEngine.GET("/logs/streamSSE/:logMonitorID", pm.streamLogsHandlerSSE) /** * User Interface Endpoints @@ -261,6 +269,7 @@ func (pm *ProxyManager) Shutdown() { }(processGroup) } wg.Wait() + pm.shutdownCancel() } func (pm *ProxyManager) swapProcessGroup(requestedModel string) (*ProcessGroup, string, error) { diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index d2da0c7..e40582a 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -1,25 +1,28 @@ package proxy import ( + "context" + "encoding/json" "net/http" "sort" - "time" "github.com/gin-gonic/gin" + "github.com/kelindar/event" ) type Model struct { - Id string `json:"id"` - State string `json:"state"` + Id string `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + State string `json:"state"` } func addApiHandlers(pm *ProxyManager) { // Add API endpoints for React to consume apiGroup := pm.ginEngine.Group("/api") { - apiGroup.GET("/models", pm.apiListModels) - 
apiGroup.GET("/modelsSSE", pm.apiListModelsSSE) apiGroup.POST("/models/unload", pm.apiUnloadAllModels) + apiGroup.GET("/events", pm.apiSendEvents) } } @@ -65,37 +68,102 @@ func (pm *ProxyManager) getModelStatus() []Model { } } models = append(models, Model{ - Id: modelID, - State: state, + Id: modelID, + Name: pm.config.Models[modelID].Name, + Description: pm.config.Models[modelID].Description, + State: state, }) } return models } -func (pm *ProxyManager) apiListModels(c *gin.Context) { - c.JSON(http.StatusOK, pm.getModelStatus()) +type messageType string + +const ( + msgTypeModelStatus messageType = "modelStatus" + msgTypeLogData messageType = "logData" +) + +type messageEnvelope struct { + Type messageType `json:"type"` + Data string `json:"data"` } -// stream the models as a SSE -func (pm *ProxyManager) apiListModelsSSE(c *gin.Context) { +// sends a stream of different message types that happen on the server +func (pm *ProxyManager) apiSendEvents(c *gin.Context) { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Content-Type-Options", "nosniff") - notify := c.Request.Context().Done() + sendBuffer := make(chan messageEnvelope, 25) + ctx, cancel := context.WithCancel(c.Request.Context()) + sendModels := func() { + data, err := json.Marshal(pm.getModelStatus()) + if err == nil { + msg := messageEnvelope{Type: msgTypeModelStatus, Data: string(data)} + select { + case sendBuffer <- msg: + case <-ctx.Done(): + return + default: + } + + } + } + + sendLogData := func(source string, data []byte) { + data, err := json.Marshal(gin.H{ + "source": source, + "data": string(data), + }) + if err == nil { + select { + case sendBuffer <- messageEnvelope{Type: msgTypeLogData, Data: string(data)}: + case <-ctx.Done(): + return + default: + } + } + } + + /** + * Send updated models list + */ + defer event.On(func(e ProcessStateChangeEvent) { + sendModels() + })() + defer event.On(func(e ConfigFileChangedEvent) { + sendModels() + })() + + /** + * Send Log data + */ + defer pm.proxyLogger.OnLogData(func(data []byte) { + sendLogData("proxy", data) + })() + defer pm.upstreamLogger.OnLogData(func(data []byte) { + sendLogData("upstream", data) + })() + + // send initial batch of data + sendLogData("proxy", pm.proxyLogger.GetHistory()) + sendLogData("upstream", pm.upstreamLogger.GetHistory()) + sendModels() - // Stream new events for { select { - case <-notify: + case <-c.Request.Context().Done(): + cancel() return - default: - models := pm.getModelStatus() - c.SSEvent("message", models) + case <-pm.shutdownCtx.Done(): + cancel() + return + case msg := <-sendBuffer: + c.SSEvent("message", msg) c.Writer.Flush() - <-time.After(1000 * time.Millisecond) } } } diff --git a/proxy/proxymanager_loghandlers.go b/proxy/proxymanager_loghandlers.go index cf18efa..d466672 100644 --- a/proxy/proxymanager_loghandlers.go +++ b/proxy/proxymanager_loghandlers.go @@ -1,6 +1,7 @@ package proxy import ( + "context" "fmt" "net/http" "strings" @@ -34,10 +35,7 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) { c.String(http.StatusBadRequest, err.Error()) return } - ch := logger.Subscribe() - defer logger.Unsubscribe(ch) - notify := c.Request.Context().Done() flusher, ok := c.Writer.(http.Flusher) if !ok { c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("streaming unsupported")) @@ -55,57 +53,28 @@ func (pm *ProxyManager) streamLogsHandler(c *gin.Context) { } } - // Stream new logs + sendChan := make(chan []byte, 10) + ctx, cancel 
:= context.WithCancel(c.Request.Context()) + defer logger.OnLogData(func(data []byte) { + select { + case sendChan <- data: + case <-ctx.Done(): + return + default: + } + })() + for { select { - case msg := <-ch: - _, err := c.Writer.Write(msg) - if err != nil { - // just break the loop if we can't write for some reason - return - } + case <-c.Request.Context().Done(): + cancel() + return + case <-pm.shutdownCtx.Done(): + cancel() + return + case data := <-sendChan: + c.Writer.Write(data) flusher.Flush() - case <-notify: - return - } - } -} - -func (pm *ProxyManager) streamLogsHandlerSSE(c *gin.Context) { - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Content-Type-Options", "nosniff") - - logMonitorId := c.Param("logMonitorID") - logger, err := pm.getLogger(logMonitorId) - if err != nil { - c.String(http.StatusBadRequest, err.Error()) - return - } - ch := logger.Subscribe() - defer logger.Unsubscribe(ch) - - notify := c.Request.Context().Done() - - // Send history first if not skipped - _, skipHistory := c.GetQuery("no-history") - if !skipHistory { - history := logger.GetHistory() - if len(history) != 0 { - c.SSEvent("message", string(history)) - c.Writer.Flush() - } - } - - // Stream new logs - for { - select { - case msg := <-ch: - c.SSEvent("message", string(msg)) - c.Writer.Flush() - case <-notify: - return } } } diff --git a/ui/src/contexts/APIProvider.tsx b/ui/src/contexts/APIProvider.tsx index b37859d..fea35c3 100644 --- a/ui/src/contexts/APIProvider.tsx +++ b/ui/src/contexts/APIProvider.tsx @@ -6,6 +6,8 @@ const LOG_LENGTH_LIMIT = 1024 * 100; /* 100KB of log data */ export interface Model { id: string; state: ModelStatus; + name: string; + description: string; } interface APIProviderType { @@ -13,12 +15,18 @@ interface APIProviderType { listModels: () => Promise; unloadAllModels: () => Promise; loadModel: (model: string) => Promise; - enableProxyLogs: (enabled: boolean) => void; - enableUpstreamLogs: (enabled: boolean) => void; - enableModelUpdates: (enabled: boolean) => void; + enableAPIEvents: (enabled: boolean) => void; proxyLogs: string; upstreamLogs: string; } +interface LogData { + source: "upstream" | "proxy"; + data: string; +} +interface APIEventEnvelope { + type: "modelStatus" | "logData"; + data: string; +} const APIContext = createContext(undefined); type APIProviderProps = { @@ -30,6 +38,7 @@ export function APIProvider({ children }: APIProviderProps) { const [upstreamLogs, setUpstreamLogs] = useState(""); const proxyEventSource = useRef(null); const upstreamEventSource = useRef(null); + const apiEventSource = useRef(null); const [models, setModels] = useState([]); const modelStatusEventSource = useRef(null); @@ -41,104 +50,61 @@ export function APIProvider({ children }: APIProviderProps) { }); }, []); - const handleProxyMessage = useCallback( - (e: MessageEvent) => { - appendLog(e.data, setProxyLogs); - }, - [proxyLogs, appendLog] - ); + const enableAPIEvents = useCallback((enabled: boolean) => { + if (!enabled) { + apiEventSource.current?.close(); + apiEventSource.current = null; + return; + } - const handleUpstreamMessage = useCallback( - (e: MessageEvent) => { - appendLog(e.data, setUpstreamLogs); - }, - [appendLog] - ); + let retryCount = 0; + const maxRetries = 3; + const initialDelay = 1000; // 1 second - const enableProxyLogs = useCallback( - (enabled: boolean) => { - if (enabled) { - let retryCount = 0; - const maxRetries = 3; - const initialDelay = 1000; // 1 second 
+ const connect = () => { + const eventSource = new EventSource("/api/events"); - const connect = () => { - const eventSource = new EventSource("/logs/streamSSE/proxy"); + eventSource.onmessage = (e: MessageEvent) => { + try { + const message = JSON.parse(e.data) as APIEventEnvelope; + switch (message.type) { + case "modelStatus": + { + const models = JSON.parse(message.data) as Model[]; + setModels(models); + } + break; - eventSource.onmessage = handleProxyMessage; - eventSource.onerror = () => { - eventSource.close(); - if (retryCount < maxRetries) { - retryCount++; - const delay = initialDelay * Math.pow(2, retryCount - 1); - setTimeout(connect, delay); + case "logData": { + const logData = JSON.parse(message.data) as LogData; + switch (logData.source) { + case "proxy": + appendLog(logData.data, setProxyLogs); + break; + case "upstream": + appendLog(logData.data, setUpstreamLogs); + break; + } } - }; - - proxyEventSource.current = eventSource; - }; - - connect(); - } else { - proxyEventSource.current?.close(); - proxyEventSource.current = null; - } - }, - [handleProxyMessage] - ); - - const enableUpstreamLogs = useCallback( - (enabled: boolean) => { - if (enabled) { - let retryCount = 0; - const maxRetries = 3; - const initialDelay = 1000; // 1 second - - const connect = () => { - const eventSource = new EventSource("/logs/streamSSE/upstream"); - - eventSource.onmessage = handleUpstreamMessage; - eventSource.onerror = () => { - eventSource.close(); - if (retryCount < maxRetries) { - retryCount++; - const delay = initialDelay * Math.pow(2, retryCount - 1); - setTimeout(connect, delay); - } - }; - - upstreamEventSource.current = eventSource; - }; - - connect(); - } else { - upstreamEventSource.current?.close(); - upstreamEventSource.current = null; - } - }, - [handleUpstreamMessage] - ); - - const enableModelUpdates = useCallback( - (enabled: boolean) => { - if (enabled) { - const eventSource = new EventSource("/api/modelsSSE"); - eventSource.onmessage = (e: MessageEvent) => { - try { - const models = JSON.parse(e.data) as Model[]; - setModels(models); - } catch (e) { - console.error(e); } - }; - modelStatusEventSource.current = eventSource; - } else { - modelStatusEventSource.current?.close(); - modelStatusEventSource.current = null; - } - }, - [setModels] - ); + } catch (err) { + console.error(e.data, err); + } + }; + eventSource.onerror = () => { + eventSource.close(); + if (retryCount < maxRetries) { + retryCount++; + const delay = initialDelay * Math.pow(2, retryCount - 1); + setTimeout(connect, delay); + } + }; + + apiEventSource.current = eventSource; + }; + + connect(); + }, []); useEffect(() => { return () => { @@ -196,23 +162,11 @@ export function APIProvider({ children }: APIProviderProps) { listModels, unloadAllModels, loadModel, - enableProxyLogs, - enableUpstreamLogs, - enableModelUpdates, + enableAPIEvents, proxyLogs, upstreamLogs, }), - [ - models, - listModels, - unloadAllModels, - loadModel, - enableProxyLogs, - enableUpstreamLogs, - enableModelUpdates, - proxyLogs, - upstreamLogs, - ] + [models, listModels, unloadAllModels, loadModel, enableAPIEvents, proxyLogs, upstreamLogs] ); return {children}; diff --git a/ui/src/pages/LogViewer.tsx b/ui/src/pages/LogViewer.tsx index cf46a51..fc31dba 100644 --- a/ui/src/pages/LogViewer.tsx +++ b/ui/src/pages/LogViewer.tsx @@ -3,14 +3,12 @@ import { useAPI } from "../contexts/APIProvider"; import { usePersistentState } from "../hooks/usePersistentState"; const LogViewer = () => { - const { proxyLogs, upstreamLogs, enableProxyLogs, 
enableUpstreamLogs } = useAPI(); + const { proxyLogs, upstreamLogs, enableAPIEvents } = useAPI(); useEffect(() => { - enableProxyLogs(true); - enableUpstreamLogs(true); + enableAPIEvents(true); return () => { - enableProxyLogs(false); - enableUpstreamLogs(false); + enableAPIEvents(false); }; }, []); diff --git a/ui/src/pages/Models.tsx b/ui/src/pages/Models.tsx index a103fcf..1fd8bee 100644 --- a/ui/src/pages/Models.tsx +++ b/ui/src/pages/Models.tsx @@ -4,15 +4,13 @@ import { LogPanel } from "./LogViewer"; import { processEvalTimes } from "../lib/Utils"; export default function ModelsPage() { - const { models, enableModelUpdates, unloadAllModels, loadModel, upstreamLogs, enableUpstreamLogs } = useAPI(); + const { models, unloadAllModels, loadModel, upstreamLogs, enableAPIEvents } = useAPI(); const [isUnloading, setIsUnloading] = useState(false); useEffect(() => { - enableModelUpdates(true); - enableUpstreamLogs(true); + enableAPIEvents(true); return () => { - enableModelUpdates(false); - enableUpstreamLogs(false); + enableAPIEvents(false); }; }, []); @@ -57,8 +55,13 @@ export default function ModelsPage() { - {model.id} + {model.name !== "" ? model.name : model.id} + {model.description != "" && ( +
+              <div>{model.description}</div>
+            )}
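
For reference, a client of the new /api/events endpoint could consume the SSE stream along these lines. This is a minimal sketch assuming the messageEnvelope shape introduced in proxymanager_api.go (type is "modelStatus" or "logData", data is a JSON-encoded payload); variable names are illustrative.

// Sketch: subscribe to the combined /api/events SSE stream (assumed envelope shape).
const source = new EventSource("/api/events");

source.onmessage = (e: MessageEvent) => {
  const msg = JSON.parse(e.data) as { type: "modelStatus" | "logData"; data: string };
  switch (msg.type) {
    case "modelStatus": {
      // data is a JSON array of { id, name, description, state }
      const models = JSON.parse(msg.data);
      console.log("models updated:", models);
      break;
    }
    case "logData": {
      // data is a JSON object of { source: "proxy" | "upstream", data: string }
      const log = JSON.parse(msg.data) as { source: "proxy" | "upstream"; data: string };
      console.log(`[${log.source}]`, log.data);
      break;
    }
  }
};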