Add support for llama-server's /infill endpoint and metrics gathering on the Activities page.
180 lines
4.5 KiB
Go
180 lines
4.5 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/tidwall/gjson"
|
|
)
|
|
|
|
type MetricsRecorder struct {
|
|
metricsMonitor *MetricsMonitor
|
|
realModelName string
|
|
// isStreaming bool
|
|
startTime time.Time
|
|
}
|
|
|
|
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
|
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
if err != nil {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
|
c.Abort()
|
|
return
|
|
}
|
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
|
|
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
|
if requestedModel == "" {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
realModelName, found := pm.config.RealModelName(requestedModel)
|
|
if !found {
|
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
writer := &MetricsResponseWriter{
|
|
ResponseWriter: c.Writer,
|
|
metricsRecorder: &MetricsRecorder{
|
|
metricsMonitor: pm.metricsMonitor,
|
|
realModelName: realModelName,
|
|
startTime: time.Now(),
|
|
},
|
|
}
|
|
c.Writer = writer
|
|
c.Next()
|
|
|
|
// check for streaming response
|
|
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
|
|
writer.metricsRecorder.processStreamingResponse(writer.body)
|
|
} else {
|
|
writer.metricsRecorder.processNonStreamingResponse(writer.body)
|
|
}
|
|
|
|
}
|
|
}
|
|
|
|
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
|
usage := jsonData.Get("usage")
|
|
timings := jsonData.Get("timings")
|
|
if !usage.Exists() && !timings.Exists() {
|
|
return false
|
|
}
|
|
|
|
// default values
|
|
outputTokens := 0
|
|
inputTokens := 0
|
|
|
|
// timings data
|
|
tokensPerSecond := -1.0
|
|
promptPerSecond := -1.0
|
|
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
|
|
|
if usage.Exists() {
|
|
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
|
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
|
}
|
|
|
|
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
|
if timings.Exists() {
|
|
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
|
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
|
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
|
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
|
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
|
}
|
|
|
|
rec.metricsMonitor.addMetrics(TokenMetrics{
|
|
Timestamp: time.Now(),
|
|
Model: rec.realModelName,
|
|
InputTokens: inputTokens,
|
|
OutputTokens: outputTokens,
|
|
PromptPerSecond: promptPerSecond,
|
|
TokensPerSecond: tokensPerSecond,
|
|
DurationMs: durationMs,
|
|
})
|
|
|
|
return true
|
|
}
|
|
|
|
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
|
// Iterate **backwards** through the lines looking for the data payload with
|
|
// usage data
|
|
lines := bytes.Split(body, []byte("\n"))
|
|
|
|
for i := len(lines) - 1; i >= 0; i-- {
|
|
line := bytes.TrimSpace(lines[i])
|
|
if len(line) == 0 {
|
|
continue
|
|
}
|
|
|
|
// SSE payload always follows "data:"
|
|
prefix := []byte("data:")
|
|
if !bytes.HasPrefix(line, prefix) {
|
|
continue
|
|
}
|
|
data := bytes.TrimSpace(line[len(prefix):])
|
|
|
|
if len(data) == 0 {
|
|
continue
|
|
}
|
|
|
|
if bytes.Equal(data, []byte("[DONE]")) {
|
|
// [DONE] line itself contains nothing of interest.
|
|
continue
|
|
}
|
|
|
|
if gjson.ValidBytes(data) {
|
|
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
|
return // short circuit if a metric was recorded
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
|
if len(body) == 0 {
|
|
return
|
|
}
|
|
|
|
// Parse JSON to extract usage information
|
|
if gjson.ValidBytes(body) {
|
|
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
|
}
|
|
}
|
|
|
|
// MetricsResponseWriter captures the entire response for non-streaming
|
|
type MetricsResponseWriter struct {
|
|
gin.ResponseWriter
|
|
body []byte
|
|
metricsRecorder *MetricsRecorder
|
|
}
|
|
|
|
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
|
n, err := w.ResponseWriter.Write(b)
|
|
if err != nil {
|
|
return n, err
|
|
}
|
|
w.body = append(w.body, b...)
|
|
return n, nil
|
|
}
|
|
|
|
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
|
w.ResponseWriter.WriteHeader(statusCode)
|
|
}
|
|
|
|
func (w *MetricsResponseWriter) Header() http.Header {
|
|
return w.ResponseWriter.Header()
|
|
}
|