Include metrics from upstream chat requests (#361)

* proxy: refactor metrics recording

- remove metrics_middleware.go as this wrapper is no longer needed. This
  also eliminates double body parsing for the modelID
- move metrics parsing to be part of MetricsMonitor
- refactor how metrics are recorded in ProxyManager
- add MetricsMonitor tests
- improve memory efficiency of processStreamingResponse
- add benchmarks for MetricsMonitor.addMetrics
- proxy: refactor MetricsMonitor to handle errors more safely
Author: Benson Wong, 2025-10-25 17:38:18 -07:00 (committed by GitHub)
parent d18dc26d01, commit e250e71e59
6 changed files with 939 additions and 289 deletions
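
Illustration only, not part of the diff: the middleware that re-parsed the request body is replaced by wrapping the upstream call, teeing the response into a buffer, and parsing token usage afterwards. A minimal, self-contained sketch of that pattern using only the standard library (names here are illustrative, not llama-swap's):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
)

// bodyCopier mirrors everything written to the client into buf so it can be
// inspected after the handler returns.
type bodyCopier struct {
	http.ResponseWriter
	buf bytes.Buffer
}

func (c *bodyCopier) Write(b []byte) (int, error) {
	c.buf.Write(b) // keep a copy for metrics parsing
	return c.ResponseWriter.Write(b)
}

// withMetrics wraps an upstream call and records usage once the response has
// been fully written; at that point errors can only be logged, not returned.
func withMetrics(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
	copier := &bodyCopier{ResponseWriter: w}
	next(copier, r)

	var payload struct {
		Usage struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(copier.buf.Bytes(), &payload); err == nil {
		fmt.Printf("input=%d output=%d\n", payload.Usage.PromptTokens, payload.Usage.CompletionTokens)
	}
}

func main() {
	upstream := func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, `{"usage":{"prompt_tokens":25,"completion_tokens":10}}`)
	}
	withMetrics(httptest.NewRecorder(), httptest.NewRequest("POST", "/v1/chat/completions", nil), upstream)
}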

View File

@@ -1,184 +0,0 @@
package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
type MetricsRecorder struct {
metricsMonitor *MetricsMonitor
realModelName string
// isStreaming bool
startTime time.Time
}
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,
metricsRecorder: &MetricsRecorder{
metricsMonitor: pm.metricsMonitor,
realModelName: realModelName,
startTime: time.Now(),
},
}
c.Writer = writer
c.Next()
// check for streaming response
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
writer.metricsRecorder.processStreamingResponse(writer.body)
} else {
writer.metricsRecorder.processNonStreamingResponse(writer.body)
}
}
}
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return false
}
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(rec.startTime).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
rec.metricsMonitor.addMetrics(TokenMetrics{
Timestamp: time.Now(),
Model: rec.realModelName,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
})
return true
}
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
// Iterate **backwards** through the lines looking for the data payload with
// usage data
lines := bytes.Split(body, []byte("\n"))
for i := len(lines) - 1; i >= 0; i-- {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
return // short circuit if a metric was recorded
}
}
}
}
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
if len(body) == 0 {
return
}
// Parse JSON to extract usage information
if gjson.ValidBytes(body) {
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
}
}
// MetricsResponseWriter captures the entire response for non-streaming
type MetricsResponseWriter struct {
gin.ResponseWriter
body []byte
metricsRecorder *MetricsRecorder
}
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
n, err := w.ResponseWriter.Write(b)
if err != nil {
return n, err
}
w.body = append(w.body, b...)
return n, nil
}
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *MetricsResponseWriter) Header() http.Header {
return w.ResponseWriter.Header()
}

View File

@@ -1,12 +1,18 @@
package proxy
import (
+"bytes"
"encoding/json"
+"fmt"
+"io"
+"net/http"
+"strings"
"sync"
"time"
+"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
-"github.com/mostlygeek/llama-swap/proxy/config"
+"github.com/tidwall/gjson"
)
// TokenMetrics represents parsed token statistics from llama-server logs
@@ -31,21 +37,18 @@ func (e TokenMetricsEvent) Type() uint32 {
return TokenMetricsEventID // defined in events.go
}
-// MetricsMonitor parses llama-server output for token statistics
-type MetricsMonitor struct {
+// metricsMonitor parses llama-server output for token statistics
+type metricsMonitor struct {
mu sync.RWMutex
metrics []TokenMetrics
maxMetrics int
nextID int
+logger *LogMonitor
}
-func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
-maxMetrics := config.MetricsMaxInMemory
-if maxMetrics <= 0 {
-maxMetrics = 1000 // Default fallback
-}
-mp := &MetricsMonitor{
+func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
+mp := &metricsMonitor{
+logger: logger,
maxMetrics: maxMetrics,
}
@@ -53,7 +56,7 @@ func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
}
// addMetrics adds a new metric to the collection and publishes an event
-func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
+func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
mp.mu.Lock()
defer mp.mu.Unlock()
@@ -66,8 +69,8 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
event.Emit(TokenMetricsEvent{Metrics: metric})
}
-// GetMetrics returns a copy of the current metrics
-func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
+// getMetrics returns a copy of the current metrics
+func (mp *metricsMonitor) getMetrics() []TokenMetrics {
mp.mu.RLock()
defer mp.mu.RUnlock()
@@ -76,9 +79,189 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
return result
}
-// GetMetricsJSON returns metrics as JSON
-func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
+// getMetricsJSON returns metrics as JSON
+func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
mp.mu.RLock()
defer mp.mu.RUnlock()
return json.Marshal(mp.metrics)
}
// wrapHandler wraps the proxy handler to extract token metrics
// if wrapHandler returns an error it is safe to assume that no
// data was sent to the client
func (mp *metricsMonitor) wrapHandler(
modelID string,
writer gin.ResponseWriter,
request *http.Request,
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
) error {
recorder := newBodyCopier(writer)
if err := next(modelID, recorder, request); err != nil {
return err
}
// after this point we have to assume that data was sent to the client
// and we can only log errors but not send them to clients
if recorder.Status() != http.StatusOK {
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
return nil
}
body := recorder.body.Bytes()
if len(body) == 0 {
mp.logger.Warn("metrics skipped, empty body")
return nil
}
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
} else {
mp.addMetrics(tm)
}
} else {
if gjson.ValidBytes(body) {
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
} else {
mp.addMetrics(tm)
}
} else {
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
}
}
return nil
}
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
// Iterate **backwards** through the body looking for the data payload with
// usage data. This avoids allocating a slice of all lines via bytes.Split.
// Start from the end of the body and scan backwards for newlines
pos := len(body)
for pos > 0 {
// Find the previous newline (or start of body)
lineStart := bytes.LastIndexByte(body[:pos], '\n')
if lineStart == -1 {
lineStart = 0
} else {
lineStart++ // Move past the newline
}
line := bytes.TrimSpace(body[lineStart:pos])
pos = lineStart - 1 // Move position before the newline for next iteration
if len(line) == 0 {
continue
}
// SSE payload always follows "data:"
prefix := []byte("data:")
if !bytes.HasPrefix(line, prefix) {
continue
}
data := bytes.TrimSpace(line[len(prefix):])
if len(data) == 0 {
continue
}
if bytes.Equal(data, []byte("[DONE]")) {
// [DONE] line itself contains nothing of interest.
continue
}
if gjson.ValidBytes(data) {
return parseMetrics(modelID, start, gjson.ParseBytes(data))
}
}
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
}
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
usage := jsonData.Get("usage")
timings := jsonData.Get("timings")
if !usage.Exists() && !timings.Exists() {
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
}
// default values
cachedTokens := -1 // unknown or missing data
outputTokens := 0
inputTokens := 0
// timings data
tokensPerSecond := -1.0
promptPerSecond := -1.0
durationMs := int(time.Since(start).Milliseconds())
if usage.Exists() {
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
}
// use llama-server's timing data for tok/sec and duration as it is more accurate
if timings.Exists() {
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
cachedTokens = int(cachedValue.Int())
}
}
return TokenMetrics{
Timestamp: time.Now(),
Model: modelID,
CachedTokens: cachedTokens,
InputTokens: inputTokens,
OutputTokens: outputTokens,
PromptPerSecond: promptPerSecond,
TokensPerSecond: tokensPerSecond,
DurationMs: durationMs,
}, nil
}
// responseBodyCopier records the response body and writes to the original response writer
// while also capturing it in a buffer for later processing
type responseBodyCopier struct {
gin.ResponseWriter
body *bytes.Buffer
tee io.Writer
start time.Time
}
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
bodyBuffer := &bytes.Buffer{}
return &responseBodyCopier{
ResponseWriter: w,
body: bodyBuffer,
tee: io.MultiWriter(w, bodyBuffer),
}
}
func (w *responseBodyCopier) Write(b []byte) (int, error) {
if w.start.IsZero() {
w.start = time.Now()
}
// Single write operation that writes to both the response and buffer
return w.tee.Write(b)
}
func (w *responseBodyCopier) WriteHeader(statusCode int) {
w.ResponseWriter.WriteHeader(statusCode)
}
func (w *responseBodyCopier) Header() http.Header {
return w.ResponseWriter.Header()
}
func (w *responseBodyCopier) StartTime() time.Time {
return w.start
}
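
A runnable distillation of the backwards SSE scan used by processStreamingResponse above (illustration only, not part of the diff; json.Valid stands in for gjson.ValidBytes):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
)

// lastJSONData walks the SSE body backwards and returns the last "data:"
// payload that is valid JSON, without splitting the whole body into lines.
func lastJSONData(body []byte) ([]byte, bool) {
	pos := len(body)
	for pos > 0 {
		lineStart := bytes.LastIndexByte(body[:pos], '\n')
		if lineStart == -1 {
			lineStart = 0
		} else {
			lineStart++ // move past the newline
		}
		line := bytes.TrimSpace(body[lineStart:pos])
		pos = lineStart - 1
		if !bytes.HasPrefix(line, []byte("data:")) {
			continue
		}
		data := bytes.TrimSpace(line[len("data:"):])
		if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
			continue
		}
		if json.Valid(data) { // stand-in for gjson.ValidBytes
			return data, true
		}
	}
	return nil, false
}

func main() {
	stream := []byte("data: {\"choices\":[{\"text\":\"hi\"}]}\n" +
		"data: {\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":20}}\n" +
		"data: [DONE]\n")
	if data, ok := lastJSONData(stream); ok {
		fmt.Printf("last JSON payload: %s\n", data)
	}
}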

View File

@@ -0,0 +1,693 @@
package proxy
import (
"encoding/json"
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/mostlygeek/llama-swap/event"
"github.com/stretchr/testify/assert"
)
func TestMetricsMonitor_AddMetrics(t *testing.T) {
t.Run("adds metrics and assigns ID", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 0, metrics[0].ID)
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("increments ID for each metric", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{Model: "model"})
}
metrics := mm.getMetrics()
assert.Equal(t, 5, len(metrics))
for i := 0; i < 5; i++ {
assert.Equal(t, i, metrics[i].ID)
}
})
t.Run("respects max metrics limit", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 3)
// Add 5 metrics
for i := 0; i < 5; i++ {
mm.addMetrics(TokenMetrics{
Model: "model",
InputTokens: i,
})
}
metrics := mm.getMetrics()
assert.Equal(t, 3, len(metrics))
// Should keep the last 3 metrics (IDs 2, 3, 4)
assert.Equal(t, 2, metrics[0].ID)
assert.Equal(t, 3, metrics[1].ID)
assert.Equal(t, 4, metrics[2].ID)
})
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
receivedEvent := make(chan TokenMetricsEvent, 1)
cancel := event.On(func(e TokenMetricsEvent) {
receivedEvent <- e
})
defer cancel()
metric := TokenMetrics{
Model: "test-model",
InputTokens: 100,
OutputTokens: 50,
}
mm.addMetrics(metric)
select {
case evt := <-receivedEvent:
assert.Equal(t, 0, evt.Metrics.ID)
assert.Equal(t, "test-model", evt.Metrics.Model)
assert.Equal(t, 100, evt.Metrics.InputTokens)
assert.Equal(t, 50, evt.Metrics.OutputTokens)
case <-time.After(1 * time.Second):
t.Fatal("timeout waiting for event")
}
})
}
func TestMetricsMonitor_GetMetrics(t *testing.T) {
t.Run("returns empty slice when no metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
metrics := mm.getMetrics()
assert.NotNil(t, metrics)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns copy of metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{Model: "model1"})
mm.addMetrics(TokenMetrics{Model: "model2"})
metrics1 := mm.getMetrics()
metrics2 := mm.getMetrics()
// Verify we got copies
assert.Equal(t, 2, len(metrics1))
assert.Equal(t, 2, len(metrics2))
// Modify the returned slice shouldn't affect the original
metrics1[0].Model = "modified"
metrics3 := mm.getMetrics()
assert.Equal(t, "model1", metrics3[0].Model)
})
}
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
assert.NotNil(t, jsonData)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 0, len(metrics))
})
t.Run("returns valid JSON with metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
mm.addMetrics(TokenMetrics{
Model: "model1",
InputTokens: 100,
OutputTokens: 50,
TokensPerSecond: 25.5,
})
mm.addMetrics(TokenMetrics{
Model: "model2",
InputTokens: 200,
OutputTokens: 100,
TokensPerSecond: 30.0,
})
jsonData, err := mm.getMetricsJSON()
assert.NoError(t, err)
var metrics []TokenMetrics
err = json.Unmarshal(jsonData, &metrics)
assert.NoError(t, err)
assert.Equal(t, 2, len(metrics))
assert.Equal(t, "model1", metrics[0].Model)
assert.Equal(t, "model2", metrics[1].Model)
})
}
func TestMetricsMonitor_WrapHandler(t *testing.T) {
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"usage": {
"prompt_tokens": 100,
"completion_tokens": 50
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("successful request with timings data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0,
"cache_n": 20
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
assert.Equal(t, 20, metrics[0].CachedTokens)
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
})
t.Run("streaming request with SSE format", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Note: SSE format requires proper line breaks - each data line followed by blank line
responseBody := `data: {"choices":[{"text":"Hello"}]}
data: {"choices":[{"text":" World"}]}
data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, "test-model", metrics[0].Model)
// When timings data is present, it takes precedence
assert.Equal(t, 10, metrics[0].InputTokens)
assert.Equal(t, 20, metrics[0].OutputTokens)
})
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("empty response body does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.WriteHeader(http.StatusOK)
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte("not valid json"))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("next handler error is propagated", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
expectedErr := assert.AnError
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
return expectedErr
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.Equal(t, expectedErr, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{"result": "ok"}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
}
func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
t.Run("captures response body", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
testData := []byte("test response body")
n, err := copier.Write(testData)
assert.NoError(t, err)
assert.Equal(t, len(testData), n)
assert.Equal(t, testData, copier.body.Bytes())
assert.Equal(t, string(testData), rec.Body.String())
})
t.Run("sets start time on first write", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
assert.True(t, copier.StartTime().IsZero())
copier.Write([]byte("test"))
assert.False(t, copier.StartTime().IsZero())
})
t.Run("preserves headers", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.Header().Set("X-Test", "value")
assert.Equal(t, "value", rec.Header().Get("X-Test"))
})
t.Run("preserves status code", func(t *testing.T) {
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
copier := newBodyCopier(ginCtx.Writer)
copier.WriteHeader(http.StatusCreated)
// Gin's ResponseWriter tracks status internally
assert.Equal(t, http.StatusCreated, copier.Status())
})
}
func TestMetricsMonitor_Concurrent(t *testing.T) {
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 1000)
var wg sync.WaitGroup
numGoroutines := 10
metricsPerGoroutine := 100
for i := 0; i < numGoroutines; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
for j := 0; j < metricsPerGoroutine; j++ {
mm.addMetrics(TokenMetrics{
Model: "test-model",
InputTokens: id*1000 + j,
OutputTokens: j,
})
}
}(i)
}
wg.Wait()
metrics := mm.getMetrics()
assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics))
})
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 100)
done := make(chan bool)
// Writer goroutine
go func() {
for i := 0; i < 50; i++ {
mm.addMetrics(TokenMetrics{Model: "test-model"})
time.Sleep(1 * time.Millisecond)
}
done <- true
}()
// Multiple reader goroutines
var wg sync.WaitGroup
for i := 0; i < 5; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 20; j++ {
_ = mm.getMetrics()
_, _ = mm.getMetricsJSON()
time.Sleep(2 * time.Millisecond)
}
}()
}
<-done
wg.Wait()
// Final check
metrics := mm.getMetrics()
assert.Equal(t, 50, len(metrics))
})
}
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
t.Run("prefers timings over usage data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Timings should take precedence over usage
responseBody := `{
"usage": {
"prompt_tokens": 50,
"completion_tokens": 25
},
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
// Should use timings values, not usage values
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles missing cache_n in timings", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `{
"timings": {
"prompt_n": 100,
"predicted_n": 50,
"prompt_per_second": 150.5,
"predicted_per_second": 25.5,
"prompt_ms": 500.0,
"predicted_ms": 1500.0
}
}`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
})
}
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
// Metrics should be found in the last data line before [DONE]
responseBody := `data: {"choices":[{"text":"First"}]}
data: {"choices":[{"text":"Second"}]}
data: {"usage":{"prompt_tokens":100,"completion_tokens":50}}
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 1, len(metrics))
assert.Equal(t, 100, metrics[0].InputTokens)
assert.Equal(t, 50, metrics[0].OutputTokens)
})
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := `data: not json
data: [DONE]
`
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
assert.NoError(t, err) // Errors after response is sent are logged, not returned
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
t.Run("handles empty streaming response", func(t *testing.T) {
mm := newMetricsMonitor(testLogger, 10)
responseBody := ``
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
w.Write([]byte(responseBody))
return nil
}
req := httptest.NewRequest("POST", "/test", nil)
rec := httptest.NewRecorder()
ginCtx, _ := gin.CreateTestContext(rec)
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
// Empty body should not trigger WrapHandler processing
assert.NoError(t, err)
metrics := mm.getMetrics()
assert.Equal(t, 0, len(metrics))
})
}
// Benchmark tests
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
mm := newMetricsMonitor(testLogger, 1000)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
// Test performance with a smaller buffer where wrapping occurs more frequently
mm := newMetricsMonitor(testLogger, 100)
metric := TokenMetrics{
Model: "test-model",
CachedTokens: 100,
InputTokens: 500,
OutputTokens: 250,
PromptPerSecond: 1200.5,
TokensPerSecond: 45.8,
DurationMs: 5000,
Timestamp: time.Now(),
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
mm.addMetrics(metric)
}
}

View File

@@ -36,7 +36,7 @@ type ProxyManager struct {
upstreamLogger *LogMonitor
muxLogger *LogMonitor
-metricsMonitor *MetricsMonitor
+metricsMonitor *metricsMonitor
processGroups map[string]*ProcessGroup
@@ -75,6 +75,13 @@ func New(config config.Config) *ProxyManager {
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
var maxMetrics int
if config.MetricsMaxInMemory <= 0 {
maxMetrics = 1000 // Default fallback
} else {
maxMetrics = config.MetricsMaxInMemory
}
pm := &ProxyManager{
config: config,
ginEngine: gin.New(),
@@ -83,7 +90,7 @@ func New(config config.Config) *ProxyManager {
muxLogger: stdoutLogger,
upstreamLogger: upstreamLogger,
-metricsMonitor: NewMetricsMonitor(&config),
+metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
processGroups: make(map[string]*ProcessGroup),
@@ -193,27 +200,25 @@ func (pm *ProxyManager) setupGinEngine() {
c.Next()
})
-mm := MetricsMiddleware(pm)
// Set up routes using the Gin engine
-pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
// Support legacy /v1/completions api, see issue #12
-pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
// Support embeddings and reranking
-pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
// llama-server's /reranking endpoint + aliases
-pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/reranking", pm.proxyOAIHandler)
-pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/rerank", pm.proxyOAIHandler)
-pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
-pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler)
// llama-server's /infill endpoint for code infilling
-pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/infill", pm.proxyOAIHandler)
// llama-server's /completion endpoint
-pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
+pm.ginEngine.POST("/completion", pm.proxyOAIHandler)
// Support audio/speech endpoint
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
@@ -474,8 +479,23 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
}
// rewrite the path
originalPath := c.Request.URL.Path
c.Request.URL.Path = remainingPath
-processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
// attempt to record metrics if it is a POST request
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
return
}
}
}
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
@@ -535,12 +555,20 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
c.Request.ContentLength = int64(len(bodyBytes))
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
} else {
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
return
}
}
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// Parse multipart form

View File

@@ -180,7 +180,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
sendLogData("proxy", pm.proxyLogger.GetHistory()) sendLogData("proxy", pm.proxyLogger.GetHistory())
sendLogData("upstream", pm.upstreamLogger.GetHistory()) sendLogData("upstream", pm.upstreamLogger.GetHistory())
sendModels() sendModels()
sendMetrics(pm.metricsMonitor.GetMetrics()) sendMetrics(pm.metricsMonitor.getMetrics())
for { for {
select { select {
@@ -198,7 +198,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
}
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
-jsonData, err := pm.metricsMonitor.GetMetricsJSON()
+jsonData, err := pm.metricsMonitor.getMetricsJSON()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
return

View File

@@ -911,76 +911,6 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
// t.Logf("%v", response)
}
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a non-streaming request
reqBody := `{"model":"model1", "stream": false}`
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15,
Models: map[string]config.ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses(StopWaitForInflightRequest)
// Make a streaming request
reqBody := `{"model":"model1", "stream": true}`
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
w := CreateTestResponseRecorder()
proxy.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
// Check that metrics were recorded
metrics := proxy.metricsMonitor.GetMetrics()
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
return
}
// Verify the last metric has the correct model
lastMetric := metrics[len(metrics)-1]
assert.Equal(t, "model1", lastMetric.Model)
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
}
func TestProxyManager_HealthEndpoint(t *testing.T) {
config := config.AddDefaultGroupToConfig(config.Config{
HealthCheckTimeout: 15, HealthCheckTimeout: 15,