Stream loading state when swapping models (#371)
Swapping models can take a long time, leaving the client with nothing but silence while the new model loads. Rather than loading the model silently in the background, this PR allows llama-swap to send status updates in the reasoning_content field of a streaming chat response. Fixes: #366
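The new flag is a global default that each model can override. Based on the config fields and the test fixture in this diff, usage looks like:

sendLoadingState: true        # global default
models:
  model1:
    cmd: path/to/cmd --port ${PORT}
    sendLoadingState: false   # per-model override wins over the global default
  model2:
    cmd: path/to/cmd --port ${PORT}   # no override, inherits the global true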
@@ -129,6 +129,9 @@ type Config struct {
    // hooks, see: #209
    Hooks HooksConfig `yaml:"hooks"`

    // send loading state in reasoning
    SendLoadingState bool `yaml:"sendLoadingState"`
}

func (c *Config) RealModelName(search string) (string, bool) {
@@ -350,6 +353,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
            )
        }

        // if sendLoadingState is nil, set it to the global config value
        // see #366
        if modelConfig.SendLoadingState == nil {
            v := config.SendLoadingState // copy it
            modelConfig.SendLoadingState = &v
        }

        config.Models[modelId] = modelConfig
    }
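The model-level field is a *bool so that an explicit sendLoadingState: false can be distinguished from the key simply being absent (nil means "inherit the global value"). The same tri-state pattern in isolation, as a minimal runnable sketch with hypothetical names:

package main

import "fmt"

type globalConfig struct{ SendLoadingState bool }
type modelConfig struct{ SendLoadingState *bool } // nil = unset, inherit global

// effective resolves the per-model value, falling back to the global default.
func effective(g globalConfig, m modelConfig) bool {
    if m.SendLoadingState != nil {
        return *m.SendLoadingState // explicit override wins
    }
    return g.SendLoadingState
}

func main() {
    off := false
    fmt.Println(effective(globalConfig{true}, modelConfig{&off})) // false
    fmt.Println(effective(globalConfig{true}, modelConfig{nil}))  // true
}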
@@ -160,6 +160,8 @@ groups:
        t.Fatalf("Failed to load config: %v", err)
    }

    modelLoadingState := false

    expected := Config{
        LogLevel:  "info",
        StartPort: 5800,
@@ -171,36 +173,41 @@ groups:
                Preload: []string{"model1", "model2"},
            },
        },
        SendLoadingState: false,
        Models: map[string]ModelConfig{
            "model1": {
                Cmd:              "path/to/cmd --arg1 one",
                Proxy:            "http://localhost:8080",
                Aliases:          []string{"m1", "model-one"},
                Env:              []string{"VAR1=value1", "VAR2=value2"},
                CheckEndpoint:    "/health",
                Name:             "Model 1",
                Description:      "This is model 1",
                SendLoadingState: &modelLoadingState,
            },
            "model2": {
                Cmd:              "path/to/server --arg1 one",
                Proxy:            "http://localhost:8081",
                Aliases:          []string{"m2"},
                Env:              []string{},
                CheckEndpoint:    "/",
                SendLoadingState: &modelLoadingState,
            },
            "model3": {
                Cmd:              "path/to/cmd --arg1 one",
                Proxy:            "http://localhost:8081",
                Aliases:          []string{"mthree"},
                Env:              []string{},
                CheckEndpoint:    "/",
                SendLoadingState: &modelLoadingState,
            },
            "model4": {
                Cmd:              "path/to/cmd --arg1 one",
                Proxy:            "http://localhost:8082",
                CheckEndpoint:    "/",
                Aliases:          []string{},
                Env:              []string{},
                SendLoadingState: &modelLoadingState,
            },
        },
        HealthCheckTimeout: 15,
@@ -152,44 +152,51 @@ groups:
        t.Fatalf("Failed to load config: %v", err)
    }

    modelLoadingState := false

    expected := Config{
        LogLevel:  "info",
        StartPort: 5800,
        Macros: MacroList{
            {"svr-path", "path/to/server"},
        },
        SendLoadingState: false,
        Models: map[string]ModelConfig{
            "model1": {
                Cmd:              "path/to/cmd --arg1 one",
                CmdStop:          "taskkill /f /t /pid ${PID}",
                Proxy:            "http://localhost:8080",
                Aliases:          []string{"m1", "model-one"},
                Env:              []string{"VAR1=value1", "VAR2=value2"},
                CheckEndpoint:    "/health",
                SendLoadingState: &modelLoadingState,
            },
            "model2": {
                Cmd:              "path/to/server --arg1 one",
                CmdStop:          "taskkill /f /t /pid ${PID}",
                Proxy:            "http://localhost:8081",
                Aliases:          []string{"m2"},
                Env:              []string{},
                CheckEndpoint:    "/",
                SendLoadingState: &modelLoadingState,
            },
            "model3": {
                Cmd:              "path/to/cmd --arg1 one",
                CmdStop:          "taskkill /f /t /pid ${PID}",
                Proxy:            "http://localhost:8081",
                Aliases:          []string{"mthree"},
                Env:              []string{},
                CheckEndpoint:    "/",
                SendLoadingState: &modelLoadingState,
            },
            "model4": {
                Cmd:              "path/to/cmd --arg1 one",
                CmdStop:          "taskkill /f /t /pid ${PID}",
                Proxy:            "http://localhost:8082",
                CheckEndpoint:    "/",
                Aliases:          []string{},
                Env:              []string{},
                SendLoadingState: &modelLoadingState,
            },
        },
        HealthCheckTimeout: 15,
@@ -35,6 +35,9 @@ type ModelConfig struct {
    // Metadata: see #264
    // Arbitrary metadata that can be exposed through the API
    Metadata map[string]any `yaml:"metadata"`

    // override global setting
    SendLoadingState *bool `yaml:"sendLoadingState"`
}

func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -50,5 +50,25 @@ models:
            }
        })
    }
}

func TestConfig_ModelSendLoadingState(t *testing.T) {
    content := `
sendLoadingState: true
models:
  model1:
    cmd: path/to/cmd --port ${PORT}
    sendLoadingState: false
  model2:
    cmd: path/to/cmd --port ${PORT}
`
    config, err := LoadConfigFromReader(strings.NewReader(content))
    assert.NoError(t, err)
    assert.True(t, config.SendLoadingState)
    if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
        assert.False(t, *config.Models["model1"].SendLoadingState)
    }
    if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
        assert.True(t, *config.Models["model2"].SendLoadingState)
    }
}
proxy/process.go (265 lines changed)
@@ -2,8 +2,10 @@ package proxy

import (
    "context"
    "encoding/json"
    "errors"
    "fmt"
    "math/rand"
    "net"
    "net/http"
    "net/http/httputil"
@@ -462,6 +464,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
}

func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
    if p.reverseProxy == nil {
        http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
        return
    }

    requestBeginTime := time.Now()
    var startDuration time.Duration
@@ -488,17 +496,44 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
        p.inFlightRequests.Done()
    }()

    // for #366
    // - extract streaming param from request context, should have been set by proxymanager
    var srw *statusResponseWriter
    swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
    // start the process on demand
    if p.CurrentState() != StateReady {
        // start a goroutine to stream loading status messages into the response writer
        // add a sync so the streaming client only runs when the goroutine has exited

        isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
        if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
            srw = newStatusResponseWriter(p, w)
            go srw.statusUpdates(swapCtx)
        } else {
            p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
        }

        beginStartTime := time.Now()
        if err := p.start(); err != nil {
            errstr := fmt.Sprintf("unable to start process: %s", err)
            cancelLoadCtx()
            if srw != nil {
                srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
                // Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
                // before closing the connection. Without this, the connection would close before
                // the goroutine can write its cleanup messages, causing incomplete SSE output.
                srw.waitForCompletion(100 * time.Millisecond)
            } else {
                http.Error(w, errstr, http.StatusBadGateway)
            }
            return
        }
        startDuration = time.Since(beginStartTime)
    }

    // should trigger srw to stop sending loading events ...
    cancelLoadCtx()

    // recover from http.ErrAbortHandler panics that can occur when the client
    // disconnects before the response is sent
    defer func() {
@@ -511,10 +546,15 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
        }
    }()

    if srw != nil {
        // Wait for the goroutine to finish writing its final messages
        const completionTimeout = 1 * time.Second
        if !srw.waitForCompletion(completionTimeout) {
            p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
        }
        p.reverseProxy.ServeHTTP(srw, r)
    } else {
        p.reverseProxy.ServeHTTP(w, r)
    }

    totalTime := time.Since(requestBeginTime)
@@ -600,3 +640,220 @@ func (p *Process) cmdStopUpstreamProcess() error {

    return nil
}

var loadingRemarks = []string{
    "Still faster than your last standup meeting...",
    "Reticulating splines...",
    "Waking up the hamsters...",
    "Teaching the model manners...",
    "Convincing the GPU to participate...",
    "Loading weights (they're heavy)...",
    "Herding electrons...",
    "Compiling excuses for the delay...",
    "Downloading more RAM...",
    "Asking the model nicely to boot up...",
    "Bribing CUDA with cookies...",
    "Still loading (blame VRAM)...",
    "The model is fashionably late...",
    "Warming up those tensors...",
    "Making the neural net do push-ups...",
    "Your patience is appreciated (really)...",
    "Almost there (probably)...",
    "Loading like it's 1999...",
    "The model forgot where it put its keys...",
    "Quantum tunneling through layers...",
    "Negotiating with the PCIe bus...",
    "Defrosting frozen parameters...",
    "Teaching attention heads to focus...",
    "Running the matrix (slowly)...",
    "Untangling transformer blocks...",
    "Calibrating the flux capacitor...",
    "Spinning up the probability wheels...",
    "Waiting for the GPU to wake from its nap...",
    "Converting caffeine to compute...",
    "Allocating virtual patience...",
    "Performing arcane CUDA rituals...",
    "The model is stuck in traffic...",
    "Inflating embeddings...",
    "Summoning computational demons...",
    "Pleading with the OOM killer...",
    "Calculating the meaning of life (still at 42)...",
    "Training the training wheels...",
    "Optimizing the optimizer...",
    "Bootstrapping the bootstrapper...",
    "Loading loading screen...",
    "Processing processing logs...",
    "Buffering buffer overflow jokes...",
    "The model hit snooze...",
    "Debugging the debugger...",
    "Compiling the compiler...",
    "Parsing the parser (meta)...",
    "Tokenizing tokens...",
    "Encoding the encoder...",
    "Hashing hash browns...",
    "Forking spoons (not forks)...",
    "The model is contemplating existence...",
    "Transcending dimensional barriers...",
    "Invoking elder tensor gods...",
    "Unfurling probability clouds...",
    "Synchronizing parallel universes...",
    "The GPU is having second thoughts...",
    "Recalibrating reality matrices...",
    "Time is an illusion, loading doubly so...",
    "Convincing bits to flip themselves...",
    "The model is reading its own documentation...",
}
type statusResponseWriter struct {
    hasWritten bool
    writer     http.ResponseWriter
    process    *Process
    wg         sync.WaitGroup // Track goroutine completion
    start      time.Time
}

func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
    s := &statusResponseWriter{
        writer:  w,
        process: p,
        start:   time.Now(),
    }

    // SSE response headers, flushed immediately so the client starts reading
    s.Header().Set("Content-Type", "text/event-stream")
    s.Header().Set("Cache-Control", "no-cache")
    s.Header().Set("Connection", "keep-alive")
    s.WriteHeader(http.StatusOK)
    s.sendLine("━━━━━")
    s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
    return s
}
// statusUpdates sends status updates to the client while the model is loading
func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
    s.wg.Add(1)
    defer s.wg.Done()

    defer func() {
        duration := time.Since(s.start)
        s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
        s.sendLine("━━━━━")
        s.sendLine(" ")
    }()

    // Create a shuffled copy of loadingRemarks
    remarks := make([]string, len(loadingRemarks))
    copy(remarks, loadingRemarks)
    rand.Shuffle(len(remarks), func(i, j int) {
        remarks[i], remarks[j] = remarks[j], remarks[i]
    })
    ri := 0

    // Pick a random duration to send a remark
    nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
    lastRemarkTime := time.Now()

    ticker := time.NewTicker(time.Second)
    defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            if s.process.CurrentState() == StateReady {
                return
            }

            // Check if it's time for a snarky remark
            if time.Since(lastRemarkTime) >= nextRemarkIn {
                remark := remarks[ri%len(remarks)]
                ri++
                s.sendLine(fmt.Sprintf("\n%s", remark))
                lastRemarkTime = time.Now()
                // Pick a new random duration for the next remark
                nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
            } else {
                s.sendData(".")
            }
        }
    }
}
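Taken together with sendLine/sendData below, a client that renders reasoning_content sees output roughly like this while the model loads (one dot per second, a shuffled remark every few seconds; the duration shown is illustrative):

━━━━━
llama-swap loading model: model1
.....
Reticulating splines...
........
Done! (12.34s)
━━━━━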
// waitForCompletion waits for the statusUpdates goroutine to finish
func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
    done := make(chan struct{})
    go func() {
        s.wg.Wait()
        close(done)
    }()

    select {
    case <-done:
        return true
    case <-time.After(timeout):
        return false
    }
}

func (s *statusResponseWriter) sendLine(line string) {
    s.sendData(line + "\n")
}
func (s *statusResponseWriter) sendData(data string) {
    // Create the proper SSE JSON structure
    type Delta struct {
        ReasoningContent string `json:"reasoning_content"`
    }
    type Choice struct {
        Delta Delta `json:"delta"`
    }
    type SSEMessage struct {
        Choices []Choice `json:"choices"`
    }

    msg := SSEMessage{
        Choices: []Choice{
            {
                Delta: Delta{
                    ReasoningContent: data,
                },
            },
        },
    }

    jsonData, err := json.Marshal(msg)
    if err != nil {
        s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
        return
    }

    // Write SSE formatted data, panic if not able to write
    _, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
    if err != nil {
        panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
    }
    s.Flush()
}
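On the wire, each call emits one OpenAI-style streaming chunk; a single progress dot, for example, serializes as:

data: {"choices":[{"delta":{"reasoning_content":"."}}]}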
func (s *statusResponseWriter) Header() http.Header {
    return s.writer.Header()
}

func (s *statusResponseWriter) Write(data []byte) (int, error) {
    return s.writer.Write(data)
}

func (s *statusResponseWriter) WriteHeader(statusCode int) {
    if s.hasWritten {
        return
    }
    s.hasWritten = true
    s.writer.WriteHeader(statusCode)
    s.Flush()
}

// Flush forwards to the underlying writer when it supports http.Flusher
func (s *statusResponseWriter) Flush() {
    if flusher, ok := s.writer.(http.Flusher); ok {
        flusher.Flush()
    }
}
@@ -25,6 +25,8 @@ const (
    PROFILE_SPLIT_CHAR = ":"
)

type proxyCtxKey string

type ProxyManager struct {
    sync.Mutex
@@ -555,6 +557,12 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
    c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
    c.Request.ContentLength = int64(len(bodyBytes))

    // issue #366 extract values that downstream handlers may need
    isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
    ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
    ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
    c.Request = c.Request.WithContext(ctx)

    if pm.metricsMonitor != nil && c.Request.Method == "POST" {
        if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
            pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
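For reference, gjson (github.com/tidwall/gjson) reads the flag straight from the raw JSON body, so no second full unmarshal of the request is needed. A minimal sketch of what the handler extracts:

package main

import (
    "fmt"

    "github.com/tidwall/gjson"
)

func main() {
    // Extracting the "stream" flag the way proxyOAIHandler does.
    body := []byte(`{"model":"model1","stream":true}`)
    fmt.Println(gjson.GetBytes(body, "stream").Bool()) // true; false if the key is absent
}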