diff --git a/config.example.yaml b/config.example.yaml
index 189352b..03b9ee6 100644
--- a/config.example.yaml
+++ b/config.example.yaml
@@ -35,6 +35,14 @@ metricsMaxInMemory: 1000
 # - it is automatically incremented for every model that uses it
 startPort: 10001
 
+# sendLoadingState: inject loading status updates into the reasoning (thinking)
+# field
+# - optional, default: false
+# - when true, a stream of loading messages will be sent to the client in the
+#   reasoning field so chat UIs can show that loading is in progress.
+# - see #366 for more details
+sendLoadingState: true
+
 # macros: a dictionary of string substitutions
 # - optional, default: empty dictionary
 # - macros are reusable snippets
@@ -184,6 +192,10 @@ models:
     # - recommended to be omitted and the default used
     concurrencyLimit: 0
 
+    # sendLoadingState: overrides the global sendLoadingState setting for this model
+    # - optional, default: undefined (use global setting)
+    sendLoadingState: false
+
   # Unlisted model example:
   "qwen-unlisted":
     # unlisted: boolean, true or false
diff --git a/proxy/config/config.go b/proxy/config/config.go
index c6f478c..d957cd4 100644
--- a/proxy/config/config.go
+++ b/proxy/config/config.go
@@ -129,6 +129,9 @@ type Config struct {
 
     // hooks, see: #209
     Hooks HooksConfig `yaml:"hooks"`
+
+    // send loading state in reasoning
+    SendLoadingState bool `yaml:"sendLoadingState"`
 }
 
 func (c *Config) RealModelName(search string) (string, bool) {
@@ -350,6 +353,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
             )
         }
 
+        // if sendLoadingState is nil, set it to the global config value
+        // see #366
+        if modelConfig.SendLoadingState == nil {
+            v := config.SendLoadingState // copy it
+            modelConfig.SendLoadingState = &v
+        }
+
         config.Models[modelId] = modelConfig
     }
 
diff --git a/proxy/config/config_posix_test.go b/proxy/config/config_posix_test.go
index abf0ea7..d4b63b8 100644
--- a/proxy/config/config_posix_test.go
+++ b/proxy/config/config_posix_test.go
@@ -160,6 +160,8 @@ groups:
         t.Fatalf("Failed to load config: %v", err)
     }
 
+    modelLoadingState := false
+
     expected := Config{
         LogLevel:  "info",
         StartPort: 5800,
@@ -171,36 +173,41 @@
                 Preload: []string{"model1", "model2"},
             },
         },
+        SendLoadingState: false,
         Models: map[string]ModelConfig{
             "model1": {
-                Cmd:           "path/to/cmd --arg1 one",
-                Proxy:         "http://localhost:8080",
-                Aliases:       []string{"m1", "model-one"},
-                Env:           []string{"VAR1=value1", "VAR2=value2"},
-                CheckEndpoint: "/health",
-                Name:          "Model 1",
-                Description:   "This is model 1",
+                Cmd:              "path/to/cmd --arg1 one",
+                Proxy:            "http://localhost:8080",
+                Aliases:          []string{"m1", "model-one"},
+                Env:              []string{"VAR1=value1", "VAR2=value2"},
+                CheckEndpoint:    "/health",
+                Name:             "Model 1",
+                Description:      "This is model 1",
+                SendLoadingState: &modelLoadingState,
             },
             "model2": {
-                Cmd:           "path/to/server --arg1 one",
-                Proxy:         "http://localhost:8081",
-                Aliases:       []string{"m2"},
-                Env:           []string{},
-                CheckEndpoint: "/",
+                Cmd:              "path/to/server --arg1 one",
+                Proxy:            "http://localhost:8081",
+                Aliases:          []string{"m2"},
+                Env:              []string{},
+                CheckEndpoint:    "/",
+                SendLoadingState: &modelLoadingState,
             },
             "model3": {
-                Cmd:           "path/to/cmd --arg1 one",
-                Proxy:         "http://localhost:8081",
-                Aliases:       []string{"mthree"},
-                Env:           []string{},
-                CheckEndpoint: "/",
+                Cmd:              "path/to/cmd --arg1 one",
+                Proxy:            "http://localhost:8081",
+                Aliases:          []string{"mthree"},
+                Env:              []string{},
+                CheckEndpoint:    "/",
+                SendLoadingState: &modelLoadingState,
             },
             "model4": {
-                Cmd:           "path/to/cmd --arg1 one",
-                Proxy:         "http://localhost:8082",
-                CheckEndpoint: "/",
-                Aliases:       []string{},
-                Env:           []string{},
+                Cmd:              "path/to/cmd --arg1 one",
+                Proxy:            "http://localhost:8082",
+                CheckEndpoint:    "/",
+                Aliases:          []string{},
+                Env:              []string{},
+                SendLoadingState: &modelLoadingState,
             },
         },
         HealthCheckTimeout: 15,
diff --git a/proxy/config/config_windows_test.go b/proxy/config/config_windows_test.go
index 20ff692..c2136e4 100644
--- a/proxy/config/config_windows_test.go
+++ b/proxy/config/config_windows_test.go
@@ -152,44 +152,51 @@ groups:
         t.Fatalf("Failed to load config: %v", err)
     }
 
+    modelLoadingState := false
+
     expected := Config{
         LogLevel:  "info",
         StartPort: 5800,
         Macros: MacroList{
             {"svr-path", "path/to/server"},
         },
+        SendLoadingState: false,
         Models: map[string]ModelConfig{
             "model1": {
-                Cmd:           "path/to/cmd --arg1 one",
-                CmdStop:       "taskkill /f /t /pid ${PID}",
-                Proxy:         "http://localhost:8080",
-                Aliases:       []string{"m1", "model-one"},
-                Env:           []string{"VAR1=value1", "VAR2=value2"},
-                CheckEndpoint: "/health",
+                Cmd:              "path/to/cmd --arg1 one",
+                CmdStop:          "taskkill /f /t /pid ${PID}",
+                Proxy:            "http://localhost:8080",
+                Aliases:          []string{"m1", "model-one"},
+                Env:              []string{"VAR1=value1", "VAR2=value2"},
+                CheckEndpoint:    "/health",
+                SendLoadingState: &modelLoadingState,
             },
             "model2": {
-                Cmd:           "path/to/server --arg1 one",
-                CmdStop:       "taskkill /f /t /pid ${PID}",
-                Proxy:         "http://localhost:8081",
-                Aliases:       []string{"m2"},
-                Env:           []string{},
-                CheckEndpoint: "/",
+                Cmd:              "path/to/server --arg1 one",
+                CmdStop:          "taskkill /f /t /pid ${PID}",
+                Proxy:            "http://localhost:8081",
+                Aliases:          []string{"m2"},
+                Env:              []string{},
+                CheckEndpoint:    "/",
+                SendLoadingState: &modelLoadingState,
             },
             "model3": {
-                Cmd:           "path/to/cmd --arg1 one",
-                CmdStop:       "taskkill /f /t /pid ${PID}",
-                Proxy:         "http://localhost:8081",
-                Aliases:       []string{"mthree"},
-                Env:           []string{},
-                CheckEndpoint: "/",
+                Cmd:              "path/to/cmd --arg1 one",
+                CmdStop:          "taskkill /f /t /pid ${PID}",
+                Proxy:            "http://localhost:8081",
+                Aliases:          []string{"mthree"},
+                Env:              []string{},
+                CheckEndpoint:    "/",
+                SendLoadingState: &modelLoadingState,
             },
             "model4": {
-                Cmd:           "path/to/cmd --arg1 one",
-                CmdStop:       "taskkill /f /t /pid ${PID}",
-                Proxy:         "http://localhost:8082",
-                CheckEndpoint: "/",
-                Aliases:       []string{},
-                Env:           []string{},
+                Cmd:              "path/to/cmd --arg1 one",
+                CmdStop:          "taskkill /f /t /pid ${PID}",
+                Proxy:            "http://localhost:8082",
+                CheckEndpoint:    "/",
+                Aliases:          []string{},
+                Env:              []string{},
+                SendLoadingState: &modelLoadingState,
             },
         },
         HealthCheckTimeout: 15,
diff --git a/proxy/config/model_config.go b/proxy/config/model_config.go
index 49e78e9..f1c79e3 100644
--- a/proxy/config/model_config.go
+++ b/proxy/config/model_config.go
@@ -35,6 +35,9 @@ type ModelConfig struct {
     // Metadata: see #264
     // Arbitrary metadata that can be exposed through the API
     Metadata map[string]any `yaml:"metadata"`
+
+    // override global setting
+    SendLoadingState *bool `yaml:"sendLoadingState"`
 }
 
 func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
diff --git a/proxy/config/model_config_test.go b/proxy/config/model_config_test.go
index 3979d01..9f1e9b4 100644
--- a/proxy/config/model_config_test.go
+++ b/proxy/config/model_config_test.go
@@ -50,5 +50,25 @@ models:
             }
         })
     }
-
+}
+
+func TestConfig_ModelSendLoadingState(t *testing.T) {
+    content := `
+sendLoadingState: true
+models:
+  model1:
+    cmd: path/to/cmd --port ${PORT}
+    sendLoadingState: false
+  model2:
+    cmd: path/to/cmd --port ${PORT}
+`
+    config, err := LoadConfigFromReader(strings.NewReader(content))
+    assert.NoError(t, err)
+    assert.True(t, config.SendLoadingState)
+    if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
+        assert.False(t, *config.Models["model1"].SendLoadingState)
+    }
+    if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
+        assert.True(t, *config.Models["model2"].SendLoadingState)
+    }
 }
diff --git a/proxy/process.go b/proxy/process.go
index 775f307..6406358 100644
--- a/proxy/process.go
+++ b/proxy/process.go
@@ -2,8 +2,10 @@ package proxy
 
 import (
     "context"
+    "encoding/json"
     "errors"
     "fmt"
+    "math/rand"
     "net"
     "net/http"
     "net/http/httputil"
@@ -462,6 +464,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
 }
 
 func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
+
+    if p.reverseProxy == nil {
+        http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
+        return
+    }
+
     requestBeginTime := time.Now()
     var startDuration time.Duration
 
@@ -488,17 +496,44 @@
         p.inFlightRequests.Done()
     }()
 
+    // for #366
+    // - extract streaming param from request context, should have been set by proxymanager
+    var srw *statusResponseWriter
+    swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
     // start the process on demand
     if p.CurrentState() != StateReady {
+        // start a goroutine to stream loading status messages into the response writer
+        // add a sync so the proxied response is only written after that goroutine has exited
+
+        isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
+        if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
+            srw = newStatusResponseWriter(p, w)
+            go srw.statusUpdates(swapCtx)
+        } else {
+            p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
+        }
+
        beginStartTime := time.Now()
        if err := p.start(); err != nil {
            errstr := fmt.Sprintf("unable to start process: %s", err)
-           http.Error(w, errstr, http.StatusBadGateway)
+           cancelLoadCtx()
+           if srw != nil {
+               srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
+               // Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
+               // before closing the connection. Without this, the connection would close before
+               // the goroutine can write its cleanup messages, causing incomplete SSE output.
+               srw.waitForCompletion(100 * time.Millisecond)
+           } else {
+               http.Error(w, errstr, http.StatusBadGateway)
+           }
            return
        }
        startDuration = time.Since(beginStartTime)
    }
 
+    // cancelling the load context triggers srw to stop sending loading events
+    cancelLoadCtx()
+
     // recover from http.ErrAbortHandler panics that can occur when the client
     // disconnects before the response is sent
     defer func() {
@@ -511,10 +546,15 @@
         }
     }()
 
-    if p.reverseProxy != nil {
-        p.reverseProxy.ServeHTTP(w, r)
+    if srw != nil {
+        // Wait for the goroutine to finish writing its final messages
+        const completionTimeout = 1 * time.Second
+        if !srw.waitForCompletion(completionTimeout) {
+            p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
+        }
+        p.reverseProxy.ServeHTTP(srw, r)
     } else {
-        http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
+        p.reverseProxy.ServeHTTP(w, r)
     }
 
     totalTime := time.Since(requestBeginTime)
@@ -600,3 +640,220 @@ func (p *Process) cmdStopUpstreamProcess() error {
 
     return nil
 }
+
+var loadingRemarks = []string{
+    "Still faster than your last standup meeting...",
+    "Reticulating splines...",
+    "Waking up the hamsters...",
+    "Teaching the model manners...",
+    "Convincing the GPU to participate...",
+    "Loading weights (they're heavy)...",
+    "Herding electrons...",
+    "Compiling excuses for the delay...",
+    "Downloading more RAM...",
+    "Asking the model nicely to boot up...",
+    "Bribing CUDA with cookies...",
+    "Still loading (blame VRAM)...",
+    "The model is fashionably late...",
+    "Warming up those tensors...",
+    "Making the neural net do push-ups...",
+    "Your patience is appreciated (really)...",
+    "Almost there (probably)...",
+    "Loading like it's 1999...",
+    "The model forgot where it put its keys...",
+    "Quantum tunneling through layers...",
+    "Negotiating with the PCIe bus...",
+    "Defrosting frozen parameters...",
+    "Teaching attention heads to focus...",
+    "Running the matrix (slowly)...",
+    "Untangling transformer blocks...",
+    "Calibrating the flux capacitor...",
+    "Spinning up the probability wheels...",
+    "Waiting for the GPU to wake from its nap...",
+    "Converting caffeine to compute...",
+    "Allocating virtual patience...",
+    "Performing arcane CUDA rituals...",
+    "The model is stuck in traffic...",
+    "Inflating embeddings...",
+    "Summoning computational demons...",
+    "Pleading with the OOM killer...",
+    "Calculating the meaning of life (still at 42)...",
+    "Training the training wheels...",
+    "Optimizing the optimizer...",
+    "Bootstrapping the bootstrapper...",
+    "Loading loading screen...",
+    "Processing processing logs...",
+    "Buffering buffer overflow jokes...",
+    "The model hit snooze...",
+    "Debugging the debugger...",
+    "Compiling the compiler...",
+    "Parsing the parser (meta)...",
+    "Tokenizing tokens...",
+    "Encoding the encoder...",
+    "Hashing hash browns...",
+    "Forking spoons (not forks)...",
+    "The model is contemplating existence...",
+    "Transcending dimensional barriers...",
+    "Invoking elder tensor gods...",
+    "Unfurling probability clouds...",
+    "Synchronizing parallel universes...",
+    "The GPU is having second thoughts...",
+    "Recalibrating reality matrices...",
+    "Time is an illusion, loading doubly so...",
+    "Convincing bits to flip themselves...",
+    "The model is reading its own documentation...",
+}
+
+type statusResponseWriter struct {
+    hasWritten bool
+    writer     http.ResponseWriter
+    process    *Process
+    wg         sync.WaitGroup // Track goroutine completion
+    start      time.Time
+}
+
+func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
+    s := &statusResponseWriter{
+        writer:  w,
+        process: p,
+        start:   time.Now(),
+    }
+
+    s.Header().Set("Content-Type", "text/event-stream") // SSE
+    s.Header().Set("Cache-Control", "no-cache")         // no-cache
+    s.Header().Set("Connection", "keep-alive")          // keep-alive
+    s.WriteHeader(http.StatusOK)                        // send status code 200
+    s.sendLine("━━━━━")
+    s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
+    s.wg.Add(1) // registered before statusUpdates is launched so waitForCompletion always waits for it
+    return s
+}
+
+// statusUpdates sends status updates to the client while the model is loading
+func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
+    defer s.wg.Done()
+
+    defer func() {
+        duration := time.Since(s.start)
+        s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
+        s.sendLine("━━━━━")
+        s.sendLine(" ")
+    }()
+
+    // Create a shuffled copy of loadingRemarks
+    remarks := make([]string, len(loadingRemarks))
+    copy(remarks, loadingRemarks)
+    rand.Shuffle(len(remarks), func(i, j int) {
+        remarks[i], remarks[j] = remarks[j], remarks[i]
+    })
+    ri := 0
+
+    // Pick a random duration to send a remark
+    nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
+    lastRemarkTime := time.Now()
+
+    ticker := time.NewTicker(time.Second)
+    defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
+    for {
+        select {
+        case <-ctx.Done():
+            return
+        case <-ticker.C:
+            if s.process.CurrentState() == StateReady {
+                return
+            }
+
+            // Check if it's time for a snarky remark
+            if time.Since(lastRemarkTime) >= nextRemarkIn {
+                remark := remarks[ri%len(remarks)]
+                ri++
+                s.sendLine(fmt.Sprintf("\n%s", remark))
+                lastRemarkTime = time.Now()
+                // Pick a new random duration for the next remark
+                nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
+            } else {
+                s.sendData(".")
+            }
+        }
+    }
+}
+
+// waitForCompletion waits for the statusUpdates goroutine to finish
+func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
+    done := make(chan struct{})
+    go func() {
+        s.wg.Wait()
+        close(done)
+    }()
+
+    select {
+    case <-done:
+        return true
+    case <-time.After(timeout):
+        return false
+    }
+}
+
+func (s *statusResponseWriter) sendLine(line string) {
+    s.sendData(line + "\n")
+}
+
+func (s *statusResponseWriter) sendData(data string) {
+    // Create the proper SSE JSON structure
+    type Delta struct {
+        ReasoningContent string `json:"reasoning_content"`
+    }
+    type Choice struct {
+        Delta Delta `json:"delta"`
+    }
+    type SSEMessage struct {
+        Choices []Choice `json:"choices"`
+    }
+
+    msg := SSEMessage{
+        Choices: []Choice{
+            {
+                Delta: Delta{
+                    ReasoningContent: data,
+                },
+            },
+        },
+    }
+
+    jsonData, err := json.Marshal(msg)
+    if err != nil {
+        s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
+        return
+    }
+
+    // Write SSE formatted data, panic if not able to write
+    _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
+    if err != nil {
+        panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
+    }
+    s.Flush()
+}
+
+func (s *statusResponseWriter) Header() http.Header {
+    return s.writer.Header()
+}
+
+func (s *statusResponseWriter) Write(data []byte) (int, error) {
+    return s.writer.Write(data)
+}
+
+func (s *statusResponseWriter) WriteHeader(statusCode int) {
+    if s.hasWritten {
+        return
+    }
+    s.hasWritten = true
+    s.writer.WriteHeader(statusCode)
+    s.Flush()
+}
+
+// Add Flush method
+func (s *statusResponseWriter) Flush() {
+    if flusher, ok := s.writer.(http.Flusher); ok {
+        flusher.Flush()
+    }
+}
diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go
index 683d64b..1589341 100644
--- a/proxy/proxymanager.go
+++ b/proxy/proxymanager.go
@@ -25,6 +25,8 @@ const (
     PROFILE_SPLIT_CHAR = ":"
 )
 
+type proxyCtxKey string
+
 type ProxyManager struct {
     sync.Mutex
 
@@ -555,6 +557,12 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
     c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
     c.Request.ContentLength = int64(len(bodyBytes))
 
+    // issue #366 extract values that downstream handlers may need
+    isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
+    ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
+    ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
+    c.Request = c.Request.WithContext(ctx)
+
     if pm.metricsMonitor != nil && c.Request.Method == "POST" {
         if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
             pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
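
For reference, a sketch of what a streaming client receives in the reasoning field while a model loads, assuming the SSE framing produced by sendData above; the model name "llama-8b", the elapsed time, and the particular remark are illustrative, since the remarks are shuffled and the timing is randomized:

data: {"choices":[{"delta":{"reasoning_content":"━━━━━\n"}}]}

data: {"choices":[{"delta":{"reasoning_content":"llama-swap loading model: llama-8b\n"}}]}

data: {"choices":[{"delta":{"reasoning_content":"."}}]}

data: {"choices":[{"delta":{"reasoning_content":"\nReticulating splines...\n"}}]}

data: {"choices":[{"delta":{"reasoning_content":"\nDone! (12.34s)\n"}}]}

The first two chunks come from newStatusResponseWriter, the dots and remarks from the statusUpdates ticker, and the closing "Done!" lines from its deferred cleanup; once the model is ready, the upstream response is proxied on the same connection.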