Stream loading state when swapping models (#371)
Swapping models can take a long time and leave a lot of silence while the model is loading. Rather than silently load the model in the background, this PR allows llama-swap to send status updates in the reasoning_content of a streaming chat response. Fixes: #366
This commit is contained in:
265
proxy/process.go
265
proxy/process.go
@@ -2,8 +2,10 @@ package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
@@ -462,6 +464,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
|
||||
}
|
||||
|
||||
func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
if p.reverseProxy == nil {
|
||||
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
requestBeginTime := time.Now()
|
||||
var startDuration time.Duration
|
||||
|
||||
@@ -488,17 +496,44 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
p.inFlightRequests.Done()
|
||||
}()
|
||||
|
||||
// for #366
|
||||
// - extract streaming param from request context, should have been set by proxymanager
|
||||
var srw *statusResponseWriter
|
||||
swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
|
||||
// start the process on demand
|
||||
if p.CurrentState() != StateReady {
|
||||
// start a goroutine to stream loading status messages into the response writer
|
||||
// add a sync so the streaming client only runs when the goroutine has exited
|
||||
|
||||
isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
|
||||
if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
|
||||
srw = newStatusResponseWriter(p, w)
|
||||
go srw.statusUpdates(swapCtx)
|
||||
} else {
|
||||
p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
|
||||
}
|
||||
|
||||
beginStartTime := time.Now()
|
||||
if err := p.start(); err != nil {
|
||||
errstr := fmt.Sprintf("unable to start process: %s", err)
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
cancelLoadCtx()
|
||||
if srw != nil {
|
||||
srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
|
||||
// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
|
||||
// before closing the connection. Without this, the connection would close before
|
||||
// the goroutine can write its cleanup messages, causing incomplete SSE output.
|
||||
srw.waitForCompletion(100 * time.Millisecond)
|
||||
} else {
|
||||
http.Error(w, errstr, http.StatusBadGateway)
|
||||
}
|
||||
return
|
||||
}
|
||||
startDuration = time.Since(beginStartTime)
|
||||
}
|
||||
|
||||
// should trigger srw to stop sending loading events ...
|
||||
cancelLoadCtx()
|
||||
|
||||
// recover from http.ErrAbortHandler panics that can occur when the client
|
||||
// disconnects before the response is sent
|
||||
defer func() {
|
||||
@@ -511,10 +546,15 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
if p.reverseProxy != nil {
|
||||
p.reverseProxy.ServeHTTP(w, r)
|
||||
if srw != nil {
|
||||
// Wait for the goroutine to finish writing its final messages
|
||||
const completionTimeout = 1 * time.Second
|
||||
if !srw.waitForCompletion(completionTimeout) {
|
||||
p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
|
||||
}
|
||||
p.reverseProxy.ServeHTTP(srw, r)
|
||||
} else {
|
||||
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||
p.reverseProxy.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
totalTime := time.Since(requestBeginTime)
|
||||
@@ -600,3 +640,220 @@ func (p *Process) cmdStopUpstreamProcess() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var loadingRemarks = []string{
|
||||
"Still faster than your last standup meeting...",
|
||||
"Reticulating splines...",
|
||||
"Waking up the hamsters...",
|
||||
"Teaching the model manners...",
|
||||
"Convincing the GPU to participate...",
|
||||
"Loading weights (they're heavy)...",
|
||||
"Herding electrons...",
|
||||
"Compiling excuses for the delay...",
|
||||
"Downloading more RAM...",
|
||||
"Asking the model nicely to boot up...",
|
||||
"Bribing CUDA with cookies...",
|
||||
"Still loading (blame VRAM)...",
|
||||
"The model is fashionably late...",
|
||||
"Warming up those tensors...",
|
||||
"Making the neural net do push-ups...",
|
||||
"Your patience is appreciated (really)...",
|
||||
"Almost there (probably)...",
|
||||
"Loading like it's 1999...",
|
||||
"The model forgot where it put its keys...",
|
||||
"Quantum tunneling through layers...",
|
||||
"Negotiating with the PCIe bus...",
|
||||
"Defrosting frozen parameters...",
|
||||
"Teaching attention heads to focus...",
|
||||
"Running the matrix (slowly)...",
|
||||
"Untangling transformer blocks...",
|
||||
"Calibrating the flux capacitor...",
|
||||
"Spinning up the probability wheels...",
|
||||
"Waiting for the GPU to wake from its nap...",
|
||||
"Converting caffeine to compute...",
|
||||
"Allocating virtual patience...",
|
||||
"Performing arcane CUDA rituals...",
|
||||
"The model is stuck in traffic...",
|
||||
"Inflating embeddings...",
|
||||
"Summoning computational demons...",
|
||||
"Pleading with the OOM killer...",
|
||||
"Calculating the meaning of life (still at 42)...",
|
||||
"Training the training wheels...",
|
||||
"Optimizing the optimizer...",
|
||||
"Bootstrapping the bootstrapper...",
|
||||
"Loading loading screen...",
|
||||
"Processing processing logs...",
|
||||
"Buffering buffer overflow jokes...",
|
||||
"The model hit snooze...",
|
||||
"Debugging the debugger...",
|
||||
"Compiling the compiler...",
|
||||
"Parsing the parser (meta)...",
|
||||
"Tokenizing tokens...",
|
||||
"Encoding the encoder...",
|
||||
"Hashing hash browns...",
|
||||
"Forking spoons (not forks)...",
|
||||
"The model is contemplating existence...",
|
||||
"Transcending dimensional barriers...",
|
||||
"Invoking elder tensor gods...",
|
||||
"Unfurling probability clouds...",
|
||||
"Synchronizing parallel universes...",
|
||||
"The GPU is having second thoughts...",
|
||||
"Recalibrating reality matrices...",
|
||||
"Time is an illusion, loading doubly so...",
|
||||
"Convincing bits to flip themselves...",
|
||||
"The model is reading its own documentation...",
|
||||
}
|
||||
|
||||
type statusResponseWriter struct {
|
||||
hasWritten bool
|
||||
writer http.ResponseWriter
|
||||
process *Process
|
||||
wg sync.WaitGroup // Track goroutine completion
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
|
||||
s := &statusResponseWriter{
|
||||
writer: w,
|
||||
process: p,
|
||||
start: time.Now(),
|
||||
}
|
||||
|
||||
s.Header().Set("Content-Type", "text/event-stream") // SSE
|
||||
s.Header().Set("Cache-Control", "no-cache") // no-cache
|
||||
s.Header().Set("Connection", "keep-alive") // keep-alive
|
||||
s.WriteHeader(http.StatusOK) // send status code 200
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
|
||||
return s
|
||||
}
|
||||
|
||||
// statusUpdates sends status updates to the client while the model is loading
|
||||
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
|
||||
s.wg.Add(1)
|
||||
defer s.wg.Done()
|
||||
|
||||
defer func() {
|
||||
duration := time.Since(s.start)
|
||||
s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
|
||||
s.sendLine("━━━━━")
|
||||
s.sendLine(" ")
|
||||
}()
|
||||
|
||||
// Create a shuffled copy of loadingRemarks
|
||||
remarks := make([]string, len(loadingRemarks))
|
||||
copy(remarks, loadingRemarks)
|
||||
rand.Shuffle(len(remarks), func(i, j int) {
|
||||
remarks[i], remarks[j] = remarks[j], remarks[i]
|
||||
})
|
||||
ri := 0
|
||||
|
||||
// Pick a random duration to send a remark
|
||||
nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
|
||||
lastRemarkTime := time.Now()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if s.process.CurrentState() == StateReady {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it's time for a snarky remark
|
||||
if time.Since(lastRemarkTime) >= nextRemarkIn {
|
||||
remark := remarks[ri%len(remarks)]
|
||||
ri++
|
||||
s.sendLine(fmt.Sprintf("\n%s", remark))
|
||||
lastRemarkTime = time.Now()
|
||||
// Pick a new random duration for the next remark
|
||||
nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
|
||||
} else {
|
||||
s.sendData(".")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCompletion waits for the statusUpdates goroutine to finish
|
||||
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendLine(line string) {
|
||||
s.sendData(line + "\n")
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) sendData(data string) {
|
||||
// Create the proper SSE JSON structure
|
||||
type Delta struct {
|
||||
ReasoningContent string `json:"reasoning_content"`
|
||||
}
|
||||
type Choice struct {
|
||||
Delta Delta `json:"delta"`
|
||||
}
|
||||
type SSEMessage struct {
|
||||
Choices []Choice `json:"choices"`
|
||||
}
|
||||
|
||||
msg := SSEMessage{
|
||||
Choices: []Choice{
|
||||
{
|
||||
Delta: Delta{
|
||||
ReasoningContent: data,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Write SSE formatted data, panic if not able to write
|
||||
_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
|
||||
}
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Header() http.Header {
|
||||
return s.writer.Header()
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) Write(data []byte) (int, error) {
|
||||
return s.writer.Write(data)
|
||||
}
|
||||
|
||||
func (s *statusResponseWriter) WriteHeader(statusCode int) {
|
||||
if s.hasWritten {
|
||||
return
|
||||
}
|
||||
s.hasWritten = true
|
||||
s.writer.WriteHeader(statusCode)
|
||||
s.Flush()
|
||||
}
|
||||
|
||||
// Add Flush method
|
||||
func (s *statusResponseWriter) Flush() {
|
||||
if flusher, ok := s.writer.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user