Stream loading state when swapping models (#371)
Swapping models can take a long time and leaves the client with nothing but silence while the new model loads. Rather than loading the model silently in the background, this PR lets llama-swap stream status updates in the reasoning_content field of a streaming chat response. Fixes: #366
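While a swap is in progress, the client receives ordinary OpenAI-style streaming chunks whose delta carries only a reasoning_content field, so chat UIs render the updates as "thinking" output. Roughly, the wire traffic looks like the sketch below (illustrative only; the model name and timing are placeholders, and the exact payload is built by the sendData helper in proxy/process.go further down):

data: {"choices":[{"delta":{"reasoning_content":"llama-swap loading model: model1\n"}}]}

data: {"choices":[{"delta":{"reasoning_content":"."}}]}

data: {"choices":[{"delta":{"reasoning_content":"\nReticulating splines..."}}]}

data: {"choices":[{"delta":{"reasoning_content":"\nDone! (12.34s)\n"}}]}

Once the model reports ready, the request is proxied to the upstream server as usual and normal completion chunks follow.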
@@ -35,6 +35,14 @@ metricsMaxInMemory: 1000
 # - it is automatically incremented for every model that uses it
 startPort: 10001
 
+# sendLoadingState: inject loading status updates into the reasoning (thinking)
+# field
+# - optional, default: false
+# - when true, a stream of loading messages will be sent to the client in the
+#   reasoning field so chat UIs can show that loading is in progress.
+# - see #366 for more details
+sendLoadingState: true
+
 # macros: a dictionary of string substitutions
 # - optional, default: empty dictionary
 # - macros are reusable snippets
@@ -184,6 +192,10 @@ models:
     # - recommended to be omitted and the default used
    concurrencyLimit: 0
 
+    # sendLoadingState: overrides the global sendLoadingState setting for this model
+    # - optional, default: undefined (use global setting)
+    sendLoadingState: false
+
   # Unlisted model example:
   "qwen-unlisted":
     # unlisted: boolean, true or false
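Combined, the global flag and the per-model override compose as in this minimal sketch (it mirrors the new TestConfig_ModelSendLoadingState test further down; model names and commands are placeholders):

sendLoadingState: true          # default for every model
models:
  model1:
    cmd: path/to/cmd --port ${PORT}
    sendLoadingState: false     # this model opts out
  model2:                       # no override, inherits the global value (true)
    cmd: path/to/cmd --port ${PORT}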
@@ -129,6 +129,9 @@ type Config struct {
 
 	// hooks, see: #209
 	Hooks HooksConfig `yaml:"hooks"`
+
+	// send loading state in reasoning
+	SendLoadingState bool `yaml:"sendLoadingState"`
 }
 
 func (c *Config) RealModelName(search string) (string, bool) {
@@ -350,6 +353,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
 			)
 		}
 
+		// if sendLoadingState is nil, set it to the global config value
+		// see #366
+		if modelConfig.SendLoadingState == nil {
+			v := config.SendLoadingState // copy it
+			modelConfig.SendLoadingState = &v
+		}
+
 		config.Models[modelId] = modelConfig
 	}
 
@@ -160,6 +160,8 @@ groups:
 		t.Fatalf("Failed to load config: %v", err)
 	}
 
+	modelLoadingState := false
+
 	expected := Config{
 		LogLevel: "info",
 		StartPort: 5800,
@@ -171,6 +173,7 @@ groups:
 				Preload: []string{"model1", "model2"},
 			},
 		},
+		SendLoadingState: false,
 		Models: map[string]ModelConfig{
 			"model1": {
 				Cmd: "path/to/cmd --arg1 one",
@@ -180,6 +183,7 @@ groups:
 				CheckEndpoint: "/health",
 				Name: "Model 1",
 				Description: "This is model 1",
+				SendLoadingState: &modelLoadingState,
 			},
 			"model2": {
 				Cmd: "path/to/server --arg1 one",
@@ -187,6 +191,7 @@ groups:
 				Aliases: []string{"m2"},
 				Env: []string{},
 				CheckEndpoint: "/",
+				SendLoadingState: &modelLoadingState,
 			},
 			"model3": {
 				Cmd: "path/to/cmd --arg1 one",
@@ -194,6 +199,7 @@ groups:
 				Aliases: []string{"mthree"},
 				Env: []string{},
 				CheckEndpoint: "/",
+				SendLoadingState: &modelLoadingState,
 			},
 			"model4": {
 				Cmd: "path/to/cmd --arg1 one",
@@ -201,6 +207,7 @@ groups:
 				CheckEndpoint: "/",
 				Aliases: []string{},
 				Env: []string{},
+				SendLoadingState: &modelLoadingState,
 			},
 		},
 		HealthCheckTimeout: 15,
@@ -152,12 +152,15 @@ groups:
 		t.Fatalf("Failed to load config: %v", err)
 	}
 
+	modelLoadingState := false
+
 	expected := Config{
 		LogLevel: "info",
 		StartPort: 5800,
 		Macros: MacroList{
 			{"svr-path", "path/to/server"},
 		},
+		SendLoadingState: false,
 		Models: map[string]ModelConfig{
 			"model1": {
 				Cmd: "path/to/cmd --arg1 one",
@@ -166,6 +169,7 @@ groups:
 				Aliases: []string{"m1", "model-one"},
 				Env: []string{"VAR1=value1", "VAR2=value2"},
 				CheckEndpoint: "/health",
+				SendLoadingState: &modelLoadingState,
 			},
 			"model2": {
 				Cmd: "path/to/server --arg1 one",
@@ -174,6 +178,7 @@ groups:
 				Aliases: []string{"m2"},
 				Env: []string{},
 				CheckEndpoint: "/",
+				SendLoadingState: &modelLoadingState,
 			},
 			"model3": {
 				Cmd: "path/to/cmd --arg1 one",
@@ -182,6 +187,7 @@ groups:
 				Aliases: []string{"mthree"},
 				Env: []string{},
 				CheckEndpoint: "/",
+				SendLoadingState: &modelLoadingState,
 			},
 			"model4": {
 				Cmd: "path/to/cmd --arg1 one",
@@ -190,6 +196,7 @@ groups:
 				CheckEndpoint: "/",
 				Aliases: []string{},
 				Env: []string{},
+				SendLoadingState: &modelLoadingState,
 			},
 		},
 		HealthCheckTimeout: 15,
@@ -35,6 +35,9 @@ type ModelConfig struct {
 	// Metadata: see #264
 	// Arbitrary metadata that can be exposed through the API
 	Metadata map[string]any `yaml:"metadata"`
+
+	// override global setting
+	SendLoadingState *bool `yaml:"sendLoadingState"`
 }
 
 func (m *ModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
@@ -50,5 +50,25 @@ models:
 			}
 		})
 	}
 }
+
+func TestConfig_ModelSendLoadingState(t *testing.T) {
+	content := `
+sendLoadingState: true
+models:
+  model1:
+    cmd: path/to/cmd --port ${PORT}
+    sendLoadingState: false
+  model2:
+    cmd: path/to/cmd --port ${PORT}
+`
+	config, err := LoadConfigFromReader(strings.NewReader(content))
+	assert.NoError(t, err)
+
+	assert.True(t, config.SendLoadingState)
+	if assert.NotNil(t, config.Models["model1"].SendLoadingState) {
+		assert.False(t, *config.Models["model1"].SendLoadingState)
+	}
+	if assert.NotNil(t, config.Models["model2"].SendLoadingState) {
+		assert.True(t, *config.Models["model2"].SendLoadingState)
+	}
+}
proxy/process.go
@@ -2,8 +2,10 @@ package proxy
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
+	"math/rand"
 	"net"
 	"net/http"
 	"net/http/httputil"
@@ -462,6 +464,12 @@ func (p *Process) checkHealthEndpoint(healthURL string) error {
 }
 
 func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
+	if p.reverseProxy == nil {
+		http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
+		return
+	}
+
 	requestBeginTime := time.Now()
 	var startDuration time.Duration
 
@@ -488,17 +496,44 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
 		p.inFlightRequests.Done()
 	}()
 
+	// for #366
+	// - extract streaming param from request context, should have been set by proxymanager
+	var srw *statusResponseWriter
+	swapCtx, cancelLoadCtx := context.WithCancel(r.Context())
 	// start the process on demand
 	if p.CurrentState() != StateReady {
+		// start a goroutine to stream loading status messages into the response writer
+		// add a sync so the streaming client only runs when the goroutine has exited
+
+		isStreaming, _ := r.Context().Value(proxyCtxKey("streaming")).(bool)
+		if p.config.SendLoadingState != nil && *p.config.SendLoadingState && isStreaming {
+			srw = newStatusResponseWriter(p, w)
+			go srw.statusUpdates(swapCtx)
+		} else {
+			p.proxyLogger.Debugf("<%s> SendLoadingState is nil or false, not streaming loading state", p.ID)
+		}
+
 		beginStartTime := time.Now()
 		if err := p.start(); err != nil {
 			errstr := fmt.Sprintf("unable to start process: %s", err)
+			cancelLoadCtx()
+			if srw != nil {
+				srw.sendData(fmt.Sprintf("Unable to swap model err: %s\n", errstr))
+				// Wait for statusUpdates goroutine to finish writing its deferred "Done!" messages
+				// before closing the connection. Without this, the connection would close before
+				// the goroutine can write its cleanup messages, causing incomplete SSE output.
+				srw.waitForCompletion(100 * time.Millisecond)
+			} else {
 				http.Error(w, errstr, http.StatusBadGateway)
+			}
 			return
 		}
 		startDuration = time.Since(beginStartTime)
 	}
 
+	// should trigger srw to stop sending loading events ...
+	cancelLoadCtx()
+
 	// recover from http.ErrAbortHandler panics that can occur when the client
 	// disconnects before the response is sent
 	defer func() {
@@ -511,10 +546,15 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
 		}
 	}()
 
-	if p.reverseProxy != nil {
-		p.reverseProxy.ServeHTTP(w, r)
+	if srw != nil {
+		// Wait for the goroutine to finish writing its final messages
+		const completionTimeout = 1 * time.Second
+		if !srw.waitForCompletion(completionTimeout) {
+			p.proxyLogger.Warnf("<%s> status updates goroutine did not complete within %v, proceeding with proxy request", p.ID, completionTimeout)
+		}
+		p.reverseProxy.ServeHTTP(srw, r)
 	} else {
-		http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
+		p.reverseProxy.ServeHTTP(w, r)
 	}
 
 	totalTime := time.Since(requestBeginTime)
@@ -600,3 +640,220 @@ func (p *Process) cmdStopUpstreamProcess() error {
 
 	return nil
 }
+
+var loadingRemarks = []string{
+	"Still faster than your last standup meeting...",
+	"Reticulating splines...",
+	"Waking up the hamsters...",
+	"Teaching the model manners...",
+	"Convincing the GPU to participate...",
+	"Loading weights (they're heavy)...",
+	"Herding electrons...",
+	"Compiling excuses for the delay...",
+	"Downloading more RAM...",
+	"Asking the model nicely to boot up...",
+	"Bribing CUDA with cookies...",
+	"Still loading (blame VRAM)...",
+	"The model is fashionably late...",
+	"Warming up those tensors...",
+	"Making the neural net do push-ups...",
+	"Your patience is appreciated (really)...",
+	"Almost there (probably)...",
+	"Loading like it's 1999...",
+	"The model forgot where it put its keys...",
+	"Quantum tunneling through layers...",
+	"Negotiating with the PCIe bus...",
+	"Defrosting frozen parameters...",
+	"Teaching attention heads to focus...",
+	"Running the matrix (slowly)...",
+	"Untangling transformer blocks...",
+	"Calibrating the flux capacitor...",
+	"Spinning up the probability wheels...",
+	"Waiting for the GPU to wake from its nap...",
+	"Converting caffeine to compute...",
+	"Allocating virtual patience...",
+	"Performing arcane CUDA rituals...",
+	"The model is stuck in traffic...",
+	"Inflating embeddings...",
+	"Summoning computational demons...",
+	"Pleading with the OOM killer...",
+	"Calculating the meaning of life (still at 42)...",
+	"Training the training wheels...",
+	"Optimizing the optimizer...",
+	"Bootstrapping the bootstrapper...",
+	"Loading loading screen...",
+	"Processing processing logs...",
+	"Buffering buffer overflow jokes...",
+	"The model hit snooze...",
+	"Debugging the debugger...",
+	"Compiling the compiler...",
+	"Parsing the parser (meta)...",
+	"Tokenizing tokens...",
+	"Encoding the encoder...",
+	"Hashing hash browns...",
+	"Forking spoons (not forks)...",
+	"The model is contemplating existence...",
+	"Transcending dimensional barriers...",
+	"Invoking elder tensor gods...",
+	"Unfurling probability clouds...",
+	"Synchronizing parallel universes...",
+	"The GPU is having second thoughts...",
+	"Recalibrating reality matrices...",
+	"Time is an illusion, loading doubly so...",
+	"Convincing bits to flip themselves...",
+	"The model is reading its own documentation...",
+}
+
+type statusResponseWriter struct {
+	hasWritten bool
+	writer     http.ResponseWriter
+	process    *Process
+	wg         sync.WaitGroup // Track goroutine completion
+	start      time.Time
+}
+
+func newStatusResponseWriter(p *Process, w http.ResponseWriter) *statusResponseWriter {
+	s := &statusResponseWriter{
+		writer:  w,
+		process: p,
+		start:   time.Now(),
+	}
+
+	s.Header().Set("Content-Type", "text/event-stream") // SSE
+	s.Header().Set("Cache-Control", "no-cache")
+	s.Header().Set("Connection", "keep-alive")
+	s.WriteHeader(http.StatusOK) // send status code 200
+	s.sendLine("━━━━━")
+	s.sendLine(fmt.Sprintf("llama-swap loading model: %s", p.ID))
+	return s
+}
+
+// statusUpdates sends status updates to the client while the model is loading
+func (s *statusResponseWriter) statusUpdates(ctx context.Context) {
+	s.wg.Add(1)
+	defer s.wg.Done()
+
+	defer func() {
+		duration := time.Since(s.start)
+		s.sendLine(fmt.Sprintf("\nDone! (%.2fs)", duration.Seconds()))
+		s.sendLine("━━━━━")
+		s.sendLine(" ")
+	}()
+
+	// Create a shuffled copy of loadingRemarks
+	remarks := make([]string, len(loadingRemarks))
+	copy(remarks, loadingRemarks)
+	rand.Shuffle(len(remarks), func(i, j int) {
+		remarks[i], remarks[j] = remarks[j], remarks[i]
+	})
+	ri := 0
+
+	// Pick a random duration to send a remark
+	nextRemarkIn := time.Duration(2+rand.Intn(4)) * time.Second
+	lastRemarkTime := time.Now()
+
+	ticker := time.NewTicker(time.Second)
+	defer ticker.Stop() // Ensure ticker is stopped to prevent resource leak
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-ticker.C:
+			if s.process.CurrentState() == StateReady {
+				return
+			}
+
+			// Check if it's time for a snarky remark
+			if time.Since(lastRemarkTime) >= nextRemarkIn {
+				remark := remarks[ri%len(remarks)]
+				ri++
+				s.sendLine(fmt.Sprintf("\n%s", remark))
+				lastRemarkTime = time.Now()
+				// Pick a new random duration for the next remark
+				nextRemarkIn = time.Duration(5+rand.Intn(5)) * time.Second
+			} else {
+				s.sendData(".")
+			}
+		}
+	}
+}
+
+// waitForCompletion waits for the statusUpdates goroutine to finish
+func (s *statusResponseWriter) waitForCompletion(timeout time.Duration) bool {
+	done := make(chan struct{})
+	go func() {
+		s.wg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		return true
+	case <-time.After(timeout):
+		return false
+	}
+}
+
+func (s *statusResponseWriter) sendLine(line string) {
+	s.sendData(line + "\n")
+}
+
+func (s *statusResponseWriter) sendData(data string) {
+	// Create the proper SSE JSON structure
+	type Delta struct {
+		ReasoningContent string `json:"reasoning_content"`
+	}
+	type Choice struct {
+		Delta Delta `json:"delta"`
+	}
+	type SSEMessage struct {
+		Choices []Choice `json:"choices"`
+	}
+
+	msg := SSEMessage{
+		Choices: []Choice{
+			{
+				Delta: Delta{
+					ReasoningContent: data,
+				},
+			},
+		},
+	}
+
+	jsonData, err := json.Marshal(msg)
+	if err != nil {
+		s.process.proxyLogger.Errorf("<%s> Failed to marshal SSE message: %v", s.process.ID, err)
+		return
+	}
+
+	// Write SSE formatted data, panic if not able to write
+	_, err = fmt.Fprintf(s.writer, "data: %s\n\n", jsonData)
+	if err != nil {
+		panic(fmt.Sprintf("<%s> Failed to write SSE data: %v", s.process.ID, err))
+	}
+	s.Flush()
+}
+
+func (s *statusResponseWriter) Header() http.Header {
+	return s.writer.Header()
+}
+
+func (s *statusResponseWriter) Write(data []byte) (int, error) {
+	return s.writer.Write(data)
+}
+
+func (s *statusResponseWriter) WriteHeader(statusCode int) {
+	if s.hasWritten {
+		return
+	}
+	s.hasWritten = true
+	s.writer.WriteHeader(statusCode)
+	s.Flush()
+}
+
+// Flush forwards to the underlying writer's Flusher, if it has one
+func (s *statusResponseWriter) Flush() {
+	if flusher, ok := s.writer.(http.Flusher); ok {
+		flusher.Flush()
+	}
+}
@@ -25,6 +25,8 @@ const (
 	PROFILE_SPLIT_CHAR = ":"
 )
 
+type proxyCtxKey string
+
 type ProxyManager struct {
 	sync.Mutex
 
@@ -555,6 +557,12 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
 	c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
 	c.Request.ContentLength = int64(len(bodyBytes))
 
+	// issue #366 extract values that downstream handlers may need
+	isStreaming := gjson.GetBytes(bodyBytes, "stream").Bool()
+	ctx := context.WithValue(c.Request.Context(), proxyCtxKey("streaming"), isStreaming)
+	ctx = context.WithValue(ctx, proxyCtxKey("model"), realModelName)
+	c.Request = c.Request.WithContext(ctx)
+
 	if pm.metricsMonitor != nil && c.Request.Method == "POST" {
 		if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
 			pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
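For reference, a minimal client-side sketch (not part of this PR) of how a streaming consumer could separate the loading updates from the model's actual output. It assumes llama-swap listens on localhost:8080 and that each SSE "data:" line carries a chunk shaped like the one produced by sendData above; the endpoint, port, and model name are placeholders.

package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"
)

// chunk mirrors the subset of an OpenAI-style streaming chunk used here.
type chunk struct {
	Choices []struct {
		Delta struct {
			ReasoningContent string `json:"reasoning_content"`
			Content          string `json:"content"`
		} `json:"delta"`
	} `json:"choices"`
}

func main() {
	// Placeholder endpoint and model name.
	body := strings.NewReader(`{"model":"model1","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
	resp, err := http.Post("http://localhost:8080/v1/chat/completions", "application/json", body)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	sc := bufio.NewScanner(resp.Body)
	for sc.Scan() {
		line := strings.TrimSpace(strings.TrimPrefix(sc.Text(), "data: "))
		if line == "" || line == "[DONE]" {
			continue
		}
		var c chunk
		if json.Unmarshal([]byte(line), &c) != nil || len(c.Choices) == 0 {
			continue
		}
		// Loading/status text arrives in reasoning_content; model output arrives in content.
		if r := c.Choices[0].Delta.ReasoningContent; r != "" {
			fmt.Print(r)
		}
		if t := c.Choices[0].Delta.Content; t != "" {
			fmt.Print(t)
		}
	}
}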