package proxy import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "net/url" "os/exec" "strings" "sync" "syscall" "time" ) type ProxyManager struct { sync.Mutex config *Config currentCmd *exec.Cmd currentConfig ModelConfig logMonitor *LogMonitor } func New(config *Config) *ProxyManager { return &ProxyManager{config: config, logMonitor: NewLogMonitor()} } func (pm *ProxyManager) HandleFunc(w http.ResponseWriter, r *http.Request) { // https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md#api-endpoints if r.URL.Path == "/v1/chat/completions" { // extracts the `model` from json body pm.proxyChatRequest(w, r) } else if r.URL.Path == "/v1/models" { pm.listModels(w, r) } else if r.URL.Path == "/logs" { pm.streamLogs(w, r) } else { pm.proxyRequest(w, r) } } func (pm *ProxyManager) streamLogs(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("X-Content-Type-Options", "nosniff") ch := pm.logMonitor.Subscribe() defer pm.logMonitor.Unsubscribe(ch) notify := r.Context().Done() flusher, ok := w.(http.Flusher) if !ok { http.Error(w, "Streaming unsupported", http.StatusInternalServerError) return } skipHistory := r.URL.Query().Has("skip") if !skipHistory { // Send history first history := pm.logMonitor.GetHistory() if len(history) != 0 { w.Write(history) flusher.Flush() } } if !r.URL.Query().Has("stream") { return } // Stream new logs for { select { case msg := <-ch: w.Write(msg) flusher.Flush() case <-notify: return } } } func (pm *ProxyManager) listModels(w http.ResponseWriter, _ *http.Request) { data := []interface{}{} for id := range pm.config.Models { data = append(data, map[string]interface{}{ "id": id, "object": "model", "created": time.Now().Unix(), "owned_by": "llama-swap", }) } // Set the Content-Type header to application/json w.Header().Set("Content-Type", "application/json") // Encode the data as JSON and write it to the response writer if err := json.NewEncoder(w).Encode(map[string]interface{}{"data": data}); err != nil { http.Error(w, "Error encoding JSON", http.StatusInternalServerError) return } } func (pm *ProxyManager) swapModel(requestedModel string) error { pm.Lock() defer pm.Unlock() // find the model configuration matching requestedModel modelConfig, found := pm.config.FindConfig(requestedModel) if !found { return fmt.Errorf("could not find configuration for %s", requestedModel) } // no need to swap llama.cpp instances if pm.currentConfig.Cmd == modelConfig.Cmd { return nil } // kill the current running one to swap it if pm.currentCmd != nil { pm.currentCmd.Process.Signal(syscall.SIGTERM) // wait for it to end pm.currentCmd.Process.Wait() } pm.currentConfig = modelConfig args, err := modelConfig.SanitizedCommand() if err != nil { return fmt.Errorf("unable to get sanitized command: %v", err) } cmd := exec.Command(args[0], args[1:]...) // logMonitor only writes to stdout // so the upstream's stderr will go to os.Stdout cmd.Stdout = pm.logMonitor cmd.Stderr = pm.logMonitor cmd.Env = modelConfig.Env err = cmd.Start() if err != nil { return err } pm.currentCmd = cmd // watch for the command to exist cmdCtx, cancel := context.WithCancelCause(context.Background()) // monitor the command's exist status go func() { err := cmd.Wait() if err != nil { cancel(fmt.Errorf("command [%s] %s", strings.Join(cmd.Args, " "), err.Error())) } else { cancel(nil) } }() // wait for checkHealthEndpoint if err := pm.checkHealthEndpoint(cmdCtx); err != nil { return err } return nil } func (pm *ProxyManager) checkHealthEndpoint(cmdCtx context.Context) error { if pm.currentConfig.Proxy == "" { return fmt.Errorf("no upstream available to check /health") } checkEndpoint := strings.TrimSpace(pm.currentConfig.CheckEndpoint) if checkEndpoint == "none" { return nil } // keep default behaviour if checkEndpoint == "" { checkEndpoint = "/health" } proxyTo := pm.currentConfig.Proxy maxDuration := time.Second * time.Duration(pm.config.HealthCheckTimeout) healthURL, err := url.JoinPath(proxyTo, checkEndpoint) if err != nil { return fmt.Errorf("failed to create health url with with %s and path %s", proxyTo, checkEndpoint) } client := &http.Client{} startTime := time.Now() for { req, err := http.NewRequest("GET", healthURL, nil) if err != nil { return err } ctx, cancel := context.WithTimeout(cmdCtx, 250*time.Millisecond) defer cancel() req = req.WithContext(ctx) resp, err := client.Do(req) ttl := (maxDuration - time.Since(startTime)).Seconds() if err != nil { // check if the context was cancelled select { case <-ctx.Done(): return context.Cause(ctx) default: } // wait a bit longer for TCP connection issues if strings.Contains(err.Error(), "connection refused") { fmt.Fprintf(pm.logMonitor, "Connection refused on %s, ttl %.0fs\n", healthURL, ttl) time.Sleep(5 * time.Second) } else { time.Sleep(time.Second) } if ttl < 0 { return fmt.Errorf("failed to check health from: %s", healthURL) } continue } defer resp.Body.Close() if resp.StatusCode == http.StatusOK { return nil } if ttl < 0 { return fmt.Errorf("failed to check health from: %s", healthURL) } time.Sleep(time.Second) } } func (pm *ProxyManager) proxyChatRequest(w http.ResponseWriter, r *http.Request) { bodyBytes, err := io.ReadAll(r.Body) if err != nil { http.Error(w, "Invalid JSON", http.StatusBadRequest) return } var requestBody map[string]interface{} if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { http.Error(w, "Invalid JSON", http.StatusBadRequest) return } model, ok := requestBody["model"].(string) if !ok { http.Error(w, "Missing or invalid 'model' key", http.StatusBadRequest) return } if err := pm.swapModel(model); err != nil { http.Error(w, fmt.Sprintf("unable to swap to model, %s", err.Error()), http.StatusNotFound) return } r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) pm.proxyRequest(w, r) } func (pm *ProxyManager) proxyRequest(w http.ResponseWriter, r *http.Request) { if pm.currentConfig.Proxy == "" { http.Error(w, "No upstream proxy", http.StatusInternalServerError) return } proxyTo := pm.currentConfig.Proxy client := &http.Client{} req, err := http.NewRequest(r.Method, proxyTo+r.URL.String(), r.Body) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } req.Header = r.Header resp, err := client.Do(req) if err != nil { http.Error(w, err.Error(), http.StatusBadGateway) return } defer resp.Body.Close() for k, vv := range resp.Header { for _, v := range vv { w.Header().Add(k, v) } } w.WriteHeader(resp.StatusCode) // faster than io.Copy when streaming buf := make([]byte, 32*1024) for { n, err := resp.Body.Read(buf) if n > 0 { if _, writeErr := w.Write(buf[:n]); writeErr != nil { http.Error(w, writeErr.Error(), http.StatusInternalServerError) return } if flusher, ok := w.(http.Flusher); ok { flusher.Flush() } } if err == io.EOF { break } if err != nil { http.Error(w, err.Error(), http.StatusBadGateway) return } } }