Include metrics from upstream chat requests (#361)
* proxy: refactor metrics recording - remove metrics_middleware.go as this wrapper is no longer needed. This also eliminates double body parsing for the modelID - move metrics parsing to be part of MetricsMonitor - refactor how metrics are recorded in ProxyManager - add MetricsMonitor tests - improve mem efficiency of processStreamingResponse - add benchmarks for MetricsMonitor.addMetrics - proxy: refactor MetricsMonitor to be more safe handling errors
This commit is contained in:
@@ -1,184 +0,0 @@
|
|||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MetricsRecorder struct {
|
|
||||||
metricsMonitor *MetricsMonitor
|
|
||||||
realModelName string
|
|
||||||
// isStreaming bool
|
|
||||||
startTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
|
||||||
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
|
||||||
return func(c *gin.Context) {
|
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
|
||||||
if err != nil {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
|
||||||
|
|
||||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
|
||||||
if requestedModel == "" {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
|
||||||
if !found {
|
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
|
||||||
c.Abort()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
writer := &MetricsResponseWriter{
|
|
||||||
ResponseWriter: c.Writer,
|
|
||||||
metricsRecorder: &MetricsRecorder{
|
|
||||||
metricsMonitor: pm.metricsMonitor,
|
|
||||||
realModelName: realModelName,
|
|
||||||
startTime: time.Now(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
c.Writer = writer
|
|
||||||
c.Next()
|
|
||||||
|
|
||||||
// check for streaming response
|
|
||||||
if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") {
|
|
||||||
writer.metricsRecorder.processStreamingResponse(writer.body)
|
|
||||||
} else {
|
|
||||||
writer.metricsRecorder.processNonStreamingResponse(writer.body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool {
|
|
||||||
usage := jsonData.Get("usage")
|
|
||||||
timings := jsonData.Get("timings")
|
|
||||||
if !usage.Exists() && !timings.Exists() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// default values
|
|
||||||
cachedTokens := -1 // unknown or missing data
|
|
||||||
outputTokens := 0
|
|
||||||
inputTokens := 0
|
|
||||||
|
|
||||||
// timings data
|
|
||||||
tokensPerSecond := -1.0
|
|
||||||
promptPerSecond := -1.0
|
|
||||||
durationMs := int(time.Since(rec.startTime).Milliseconds())
|
|
||||||
|
|
||||||
if usage.Exists() {
|
|
||||||
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
|
||||||
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
|
||||||
}
|
|
||||||
|
|
||||||
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
|
||||||
if timings.Exists() {
|
|
||||||
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
|
||||||
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
|
||||||
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
|
||||||
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
|
||||||
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
|
||||||
|
|
||||||
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
|
||||||
cachedTokens = int(cachedValue.Int())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
rec.metricsMonitor.addMetrics(TokenMetrics{
|
|
||||||
Timestamp: time.Now(),
|
|
||||||
Model: rec.realModelName,
|
|
||||||
CachedTokens: cachedTokens,
|
|
||||||
InputTokens: inputTokens,
|
|
||||||
OutputTokens: outputTokens,
|
|
||||||
PromptPerSecond: promptPerSecond,
|
|
||||||
TokensPerSecond: tokensPerSecond,
|
|
||||||
DurationMs: durationMs,
|
|
||||||
})
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
|
||||||
// Iterate **backwards** through the lines looking for the data payload with
|
|
||||||
// usage data
|
|
||||||
lines := bytes.Split(body, []byte("\n"))
|
|
||||||
|
|
||||||
for i := len(lines) - 1; i >= 0; i-- {
|
|
||||||
line := bytes.TrimSpace(lines[i])
|
|
||||||
if len(line) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// SSE payload always follows "data:"
|
|
||||||
prefix := []byte("data:")
|
|
||||||
if !bytes.HasPrefix(line, prefix) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
data := bytes.TrimSpace(line[len(prefix):])
|
|
||||||
|
|
||||||
if len(data) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if bytes.Equal(data, []byte("[DONE]")) {
|
|
||||||
// [DONE] line itself contains nothing of interest.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if gjson.ValidBytes(data) {
|
|
||||||
if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) {
|
|
||||||
return // short circuit if a metric was recorded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
|
||||||
if len(body) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse JSON to extract usage information
|
|
||||||
if gjson.ValidBytes(body) {
|
|
||||||
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MetricsResponseWriter captures the entire response for non-streaming
|
|
||||||
type MetricsResponseWriter struct {
|
|
||||||
gin.ResponseWriter
|
|
||||||
body []byte
|
|
||||||
metricsRecorder *MetricsRecorder
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
|
||||||
n, err := w.ResponseWriter.Write(b)
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
w.body = append(w.body, b...)
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
|
||||||
w.ResponseWriter.WriteHeader(statusCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *MetricsResponseWriter) Header() http.Header {
|
|
||||||
return w.ResponseWriter.Header()
|
|
||||||
}
|
|
||||||
@@ -1,12 +1,18 @@
|
|||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/mostlygeek/llama-swap/event"
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
"github.com/mostlygeek/llama-swap/proxy/config"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||||
@@ -31,21 +37,18 @@ func (e TokenMetricsEvent) Type() uint32 {
|
|||||||
return TokenMetricsEventID // defined in events.go
|
return TokenMetricsEventID // defined in events.go
|
||||||
}
|
}
|
||||||
|
|
||||||
// MetricsMonitor parses llama-server output for token statistics
|
// metricsMonitor parses llama-server output for token statistics
|
||||||
type MetricsMonitor struct {
|
type metricsMonitor struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
metrics []TokenMetrics
|
metrics []TokenMetrics
|
||||||
maxMetrics int
|
maxMetrics int
|
||||||
nextID int
|
nextID int
|
||||||
|
logger *LogMonitor
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
|
func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor {
|
||||||
maxMetrics := config.MetricsMaxInMemory
|
mp := &metricsMonitor{
|
||||||
if maxMetrics <= 0 {
|
logger: logger,
|
||||||
maxMetrics = 1000 // Default fallback
|
|
||||||
}
|
|
||||||
|
|
||||||
mp := &MetricsMonitor{
|
|
||||||
maxMetrics: maxMetrics,
|
maxMetrics: maxMetrics,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -53,7 +56,7 @@ func NewMetricsMonitor(config *config.Config) *MetricsMonitor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addMetrics adds a new metric to the collection and publishes an event
|
// addMetrics adds a new metric to the collection and publishes an event
|
||||||
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
func (mp *metricsMonitor) addMetrics(metric TokenMetrics) {
|
||||||
mp.mu.Lock()
|
mp.mu.Lock()
|
||||||
defer mp.mu.Unlock()
|
defer mp.mu.Unlock()
|
||||||
|
|
||||||
@@ -66,8 +69,8 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
|||||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMetrics returns a copy of the current metrics
|
// getMetrics returns a copy of the current metrics
|
||||||
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
func (mp *metricsMonitor) getMetrics() []TokenMetrics {
|
||||||
mp.mu.RLock()
|
mp.mu.RLock()
|
||||||
defer mp.mu.RUnlock()
|
defer mp.mu.RUnlock()
|
||||||
|
|
||||||
@@ -76,9 +79,189 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetMetricsJSON returns metrics as JSON
|
// getMetricsJSON returns metrics as JSON
|
||||||
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) {
|
||||||
mp.mu.RLock()
|
mp.mu.RLock()
|
||||||
defer mp.mu.RUnlock()
|
defer mp.mu.RUnlock()
|
||||||
return json.Marshal(mp.metrics)
|
return json.Marshal(mp.metrics)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// wrapHandler wraps the proxy handler to extract token metrics
|
||||||
|
// if wrapHandler returns an error it is safe to assume that no
|
||||||
|
// data was sent to the client
|
||||||
|
func (mp *metricsMonitor) wrapHandler(
|
||||||
|
modelID string,
|
||||||
|
writer gin.ResponseWriter,
|
||||||
|
request *http.Request,
|
||||||
|
next func(modelID string, w http.ResponseWriter, r *http.Request) error,
|
||||||
|
) error {
|
||||||
|
recorder := newBodyCopier(writer)
|
||||||
|
if err := next(modelID, recorder, request); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// after this point we have to assume that data was sent to the client
|
||||||
|
// and we can only log errors but not send them to clients
|
||||||
|
|
||||||
|
if recorder.Status() != http.StatusOK {
|
||||||
|
mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
body := recorder.body.Bytes()
|
||||||
|
if len(body) == 0 {
|
||||||
|
mp.logger.Warn("metrics skipped, empty body")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") {
|
||||||
|
if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil {
|
||||||
|
mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
mp.addMetrics(tm)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if gjson.ValidBytes(body) {
|
||||||
|
if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil {
|
||||||
|
mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path)
|
||||||
|
} else {
|
||||||
|
mp.addMetrics(tm)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) {
|
||||||
|
// Iterate **backwards** through the body looking for the data payload with
|
||||||
|
// usage data. This avoids allocating a slice of all lines via bytes.Split.
|
||||||
|
|
||||||
|
// Start from the end of the body and scan backwards for newlines
|
||||||
|
pos := len(body)
|
||||||
|
for pos > 0 {
|
||||||
|
// Find the previous newline (or start of body)
|
||||||
|
lineStart := bytes.LastIndexByte(body[:pos], '\n')
|
||||||
|
if lineStart == -1 {
|
||||||
|
lineStart = 0
|
||||||
|
} else {
|
||||||
|
lineStart++ // Move past the newline
|
||||||
|
}
|
||||||
|
|
||||||
|
line := bytes.TrimSpace(body[lineStart:pos])
|
||||||
|
pos = lineStart - 1 // Move position before the newline for next iteration
|
||||||
|
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE payload always follows "data:"
|
||||||
|
prefix := []byte("data:")
|
||||||
|
if !bytes.HasPrefix(line, prefix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[len(prefix):])
|
||||||
|
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
// [DONE] line itself contains nothing of interest.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if gjson.ValidBytes(data) {
|
||||||
|
return parseMetrics(modelID, start, gjson.ParseBytes(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) {
|
||||||
|
usage := jsonData.Get("usage")
|
||||||
|
timings := jsonData.Get("timings")
|
||||||
|
if !usage.Exists() && !timings.Exists() {
|
||||||
|
return TokenMetrics{}, fmt.Errorf("no usage or timings data found")
|
||||||
|
}
|
||||||
|
// default values
|
||||||
|
cachedTokens := -1 // unknown or missing data
|
||||||
|
outputTokens := 0
|
||||||
|
inputTokens := 0
|
||||||
|
|
||||||
|
// timings data
|
||||||
|
tokensPerSecond := -1.0
|
||||||
|
promptPerSecond := -1.0
|
||||||
|
durationMs := int(time.Since(start).Milliseconds())
|
||||||
|
|
||||||
|
if usage.Exists() {
|
||||||
|
outputTokens = int(jsonData.Get("usage.completion_tokens").Int())
|
||||||
|
inputTokens = int(jsonData.Get("usage.prompt_tokens").Int())
|
||||||
|
}
|
||||||
|
|
||||||
|
// use llama-server's timing data for tok/sec and duration as it is more accurate
|
||||||
|
if timings.Exists() {
|
||||||
|
inputTokens = int(jsonData.Get("timings.prompt_n").Int())
|
||||||
|
outputTokens = int(jsonData.Get("timings.predicted_n").Int())
|
||||||
|
promptPerSecond = jsonData.Get("timings.prompt_per_second").Float()
|
||||||
|
tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float()
|
||||||
|
durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float())
|
||||||
|
|
||||||
|
if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() {
|
||||||
|
cachedTokens = int(cachedValue.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return TokenMetrics{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Model: modelID,
|
||||||
|
CachedTokens: cachedTokens,
|
||||||
|
InputTokens: inputTokens,
|
||||||
|
OutputTokens: outputTokens,
|
||||||
|
PromptPerSecond: promptPerSecond,
|
||||||
|
TokensPerSecond: tokensPerSecond,
|
||||||
|
DurationMs: durationMs,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseBodyCopier records the response body and writes to the original response writer
|
||||||
|
// while also capturing it in a buffer for later processing
|
||||||
|
type responseBodyCopier struct {
|
||||||
|
gin.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
tee io.Writer
|
||||||
|
start time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier {
|
||||||
|
bodyBuffer := &bytes.Buffer{}
|
||||||
|
return &responseBodyCopier{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: bodyBuffer,
|
||||||
|
tee: io.MultiWriter(w, bodyBuffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) Write(b []byte) (int, error) {
|
||||||
|
if w.start.IsZero() {
|
||||||
|
w.start = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Single write operation that writes to both the response and buffer
|
||||||
|
return w.tee.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) WriteHeader(statusCode int) {
|
||||||
|
w.ResponseWriter.WriteHeader(statusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) Header() http.Header {
|
||||||
|
return w.ResponseWriter.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseBodyCopier) StartTime() time.Time {
|
||||||
|
return w.start
|
||||||
|
}
|
||||||
|
|||||||
693
proxy/metrics_monitor_test.go
Normal file
693
proxy/metrics_monitor_test.go
Normal file
@@ -0,0 +1,693 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/mostlygeek/llama-swap/event"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMetricsMonitor_AddMetrics(t *testing.T) {
|
||||||
|
t.Run("adds metrics and assigns ID", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, 0, metrics[0].ID)
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("increments ID for each metric", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "model"})
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 5, len(metrics))
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
assert.Equal(t, i, metrics[i].ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("respects max metrics limit", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 3)
|
||||||
|
|
||||||
|
// Add 5 metrics
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "model",
|
||||||
|
InputTokens: i,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 3, len(metrics))
|
||||||
|
|
||||||
|
// Should keep the last 3 metrics (IDs 2, 3, 4)
|
||||||
|
assert.Equal(t, 2, metrics[0].ID)
|
||||||
|
assert.Equal(t, 3, metrics[1].ID)
|
||||||
|
assert.Equal(t, 4, metrics[2].ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("emits TokenMetricsEvent", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
receivedEvent := make(chan TokenMetricsEvent, 1)
|
||||||
|
cancel := event.On(func(e TokenMetricsEvent) {
|
||||||
|
receivedEvent <- e
|
||||||
|
})
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
}
|
||||||
|
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case evt := <-receivedEvent:
|
||||||
|
assert.Equal(t, 0, evt.Metrics.ID)
|
||||||
|
assert.Equal(t, "test-model", evt.Metrics.Model)
|
||||||
|
assert.Equal(t, 100, evt.Metrics.InputTokens)
|
||||||
|
assert.Equal(t, 50, evt.Metrics.OutputTokens)
|
||||||
|
case <-time.After(1 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for event")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_GetMetrics(t *testing.T) {
|
||||||
|
t.Run("returns empty slice when no metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.NotNil(t, metrics)
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns copy of metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "model1"})
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "model2"})
|
||||||
|
|
||||||
|
metrics1 := mm.getMetrics()
|
||||||
|
metrics2 := mm.getMetrics()
|
||||||
|
|
||||||
|
// Verify we got copies
|
||||||
|
assert.Equal(t, 2, len(metrics1))
|
||||||
|
assert.Equal(t, 2, len(metrics2))
|
||||||
|
|
||||||
|
// Modify the returned slice shouldn't affect the original
|
||||||
|
metrics1[0].Model = "modified"
|
||||||
|
metrics3 := mm.getMetrics()
|
||||||
|
assert.Equal(t, "model1", metrics3[0].Model)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_GetMetricsJSON(t *testing.T) {
|
||||||
|
t.Run("returns valid JSON for empty metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
jsonData, err := mm.getMetricsJSON()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotNil(t, jsonData)
|
||||||
|
|
||||||
|
var metrics []TokenMetrics
|
||||||
|
err = json.Unmarshal(jsonData, &metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("returns valid JSON with metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "model1",
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
TokensPerSecond: 25.5,
|
||||||
|
})
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "model2",
|
||||||
|
InputTokens: 200,
|
||||||
|
OutputTokens: 100,
|
||||||
|
TokensPerSecond: 30.0,
|
||||||
|
})
|
||||||
|
|
||||||
|
jsonData, err := mm.getMetricsJSON()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var metrics []TokenMetrics
|
||||||
|
err = json.Unmarshal(jsonData, &metrics)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, 2, len(metrics))
|
||||||
|
assert.Equal(t, "model1", metrics[0].Model)
|
||||||
|
assert.Equal(t, "model2", metrics[1].Model)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_WrapHandler(t *testing.T) {
|
||||||
|
t.Run("successful non-streaming request with usage data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 100,
|
||||||
|
"completion_tokens": 50
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("successful request with timings data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"timings": {
|
||||||
|
"prompt_n": 100,
|
||||||
|
"predicted_n": 50,
|
||||||
|
"prompt_per_second": 150.5,
|
||||||
|
"predicted_per_second": 25.5,
|
||||||
|
"prompt_ms": 500.0,
|
||||||
|
"predicted_ms": 1500.0,
|
||||||
|
"cache_n": 20
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
assert.Equal(t, 20, metrics[0].CachedTokens)
|
||||||
|
assert.Equal(t, 150.5, metrics[0].PromptPerSecond)
|
||||||
|
assert.Equal(t, 25.5, metrics[0].TokensPerSecond)
|
||||||
|
assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("streaming request with SSE format", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
// Note: SSE format requires proper line breaks - each data line followed by blank line
|
||||||
|
responseBody := `data: {"choices":[{"text":"Hello"}]}
|
||||||
|
|
||||||
|
data: {"choices":[{"text":" World"}]}
|
||||||
|
|
||||||
|
data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, "test-model", metrics[0].Model)
|
||||||
|
// When timings data is present, it takes precedence
|
||||||
|
assert.Equal(t, 10, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 20, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-OK status code does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
|
w.Write([]byte("error"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty response body does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("invalid JSON does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte("not valid json"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("next handler error is propagated", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
expectedErr := assert.AnError
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return expectedErr
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.Equal(t, expectedErr, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("response without usage or timings does not record metrics", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{"result": "ok"}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) {
|
||||||
|
t.Run("captures response body", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
testData := []byte("test response body")
|
||||||
|
n, err := copier.Write(testData)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, len(testData), n)
|
||||||
|
assert.Equal(t, testData, copier.body.Bytes())
|
||||||
|
assert.Equal(t, string(testData), rec.Body.String())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sets start time on first write", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
assert.True(t, copier.StartTime().IsZero())
|
||||||
|
|
||||||
|
copier.Write([]byte("test"))
|
||||||
|
|
||||||
|
assert.False(t, copier.StartTime().IsZero())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves headers", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
copier.Header().Set("X-Test", "value")
|
||||||
|
|
||||||
|
assert.Equal(t, "value", rec.Header().Get("X-Test"))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves status code", func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
copier := newBodyCopier(ginCtx.Writer)
|
||||||
|
|
||||||
|
copier.WriteHeader(http.StatusCreated)
|
||||||
|
|
||||||
|
// Gin's ResponseWriter tracks status internally
|
||||||
|
assert.Equal(t, http.StatusCreated, copier.Status())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_Concurrent(t *testing.T) {
|
||||||
|
t.Run("concurrent addMetrics is safe", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 1000)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
numGoroutines := 10
|
||||||
|
metricsPerGoroutine := 100
|
||||||
|
|
||||||
|
for i := 0; i < numGoroutines; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(id int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < metricsPerGoroutine; j++ {
|
||||||
|
mm.addMetrics(TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
InputTokens: id*1000 + j,
|
||||||
|
OutputTokens: j,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("concurrent reads and writes are safe", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 100)
|
||||||
|
|
||||||
|
done := make(chan bool)
|
||||||
|
|
||||||
|
// Writer goroutine
|
||||||
|
go func() {
|
||||||
|
for i := 0; i < 50; i++ {
|
||||||
|
mm.addMetrics(TokenMetrics{Model: "test-model"})
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
}
|
||||||
|
done <- true
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Multiple reader goroutines
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 20; j++ {
|
||||||
|
_ = mm.getMetrics()
|
||||||
|
_, _ = mm.getMetricsJSON()
|
||||||
|
time.Sleep(2 * time.Millisecond)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
<-done
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
// Final check
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 50, len(metrics))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_ParseMetrics(t *testing.T) {
|
||||||
|
t.Run("prefers timings over usage data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
// Timings should take precedence over usage
|
||||||
|
responseBody := `{
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 50,
|
||||||
|
"completion_tokens": 25
|
||||||
|
},
|
||||||
|
"timings": {
|
||||||
|
"prompt_n": 100,
|
||||||
|
"predicted_n": 50,
|
||||||
|
"prompt_per_second": 150.5,
|
||||||
|
"predicted_per_second": 25.5,
|
||||||
|
"prompt_ms": 500.0,
|
||||||
|
"predicted_ms": 1500.0
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
// Should use timings values, not usage values
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles missing cache_n in timings", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `{
|
||||||
|
"timings": {
|
||||||
|
"prompt_n": 100,
|
||||||
|
"predicted_n": 50,
|
||||||
|
"prompt_per_second": 150.5,
|
||||||
|
"predicted_per_second": 25.5,
|
||||||
|
"prompt_ms": 500.0,
|
||||||
|
"predicted_ms": 1500.0
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMetricsMonitor_StreamingResponse(t *testing.T) {
|
||||||
|
t.Run("finds metrics in last valid SSE data", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
// Metrics should be found in the last data line before [DONE]
|
||||||
|
responseBody := `data: {"choices":[{"text":"First"}]}
|
||||||
|
|
||||||
|
data: {"choices":[{"text":"Second"}]}
|
||||||
|
|
||||||
|
data: {"usage":{"prompt_tokens":100,"completion_tokens":50}}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 1, len(metrics))
|
||||||
|
assert.Equal(t, 100, metrics[0].InputTokens)
|
||||||
|
assert.Equal(t, 50, metrics[0].OutputTokens)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles streaming with no valid JSON", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := `data: not json
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
|
||||||
|
`
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
assert.NoError(t, err) // Errors after response is sent are logged, not returned
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles empty streaming response", func(t *testing.T) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 10)
|
||||||
|
|
||||||
|
responseBody := ``
|
||||||
|
|
||||||
|
nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error {
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
w.Write([]byte(responseBody))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
req := httptest.NewRequest("POST", "/test", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
ginCtx, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler)
|
||||||
|
// Empty body should not trigger WrapHandler processing
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := mm.getMetrics()
|
||||||
|
assert.Equal(t, 0, len(metrics))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) {
|
||||||
|
mm := newMetricsMonitor(testLogger, 1000)
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
CachedTokens: 100,
|
||||||
|
InputTokens: 500,
|
||||||
|
OutputTokens: 250,
|
||||||
|
PromptPerSecond: 1200.5,
|
||||||
|
TokensPerSecond: 45.8,
|
||||||
|
DurationMs: 5000,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) {
|
||||||
|
// Test performance with a smaller buffer where wrapping occurs more frequently
|
||||||
|
mm := newMetricsMonitor(testLogger, 100)
|
||||||
|
|
||||||
|
metric := TokenMetrics{
|
||||||
|
Model: "test-model",
|
||||||
|
CachedTokens: 100,
|
||||||
|
InputTokens: 500,
|
||||||
|
OutputTokens: 250,
|
||||||
|
PromptPerSecond: 1200.5,
|
||||||
|
TokensPerSecond: 45.8,
|
||||||
|
DurationMs: 5000,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
mm.addMetrics(metric)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -36,7 +36,7 @@ type ProxyManager struct {
|
|||||||
upstreamLogger *LogMonitor
|
upstreamLogger *LogMonitor
|
||||||
muxLogger *LogMonitor
|
muxLogger *LogMonitor
|
||||||
|
|
||||||
metricsMonitor *MetricsMonitor
|
metricsMonitor *metricsMonitor
|
||||||
|
|
||||||
processGroups map[string]*ProcessGroup
|
processGroups map[string]*ProcessGroup
|
||||||
|
|
||||||
@@ -75,6 +75,13 @@ func New(config config.Config) *ProxyManager {
|
|||||||
|
|
||||||
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
shutdownCtx, shutdownCancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
var maxMetrics int
|
||||||
|
if config.MetricsMaxInMemory <= 0 {
|
||||||
|
maxMetrics = 1000 // Default fallback
|
||||||
|
} else {
|
||||||
|
maxMetrics = config.MetricsMaxInMemory
|
||||||
|
}
|
||||||
|
|
||||||
pm := &ProxyManager{
|
pm := &ProxyManager{
|
||||||
config: config,
|
config: config,
|
||||||
ginEngine: gin.New(),
|
ginEngine: gin.New(),
|
||||||
@@ -83,7 +90,7 @@ func New(config config.Config) *ProxyManager {
|
|||||||
muxLogger: stdoutLogger,
|
muxLogger: stdoutLogger,
|
||||||
upstreamLogger: upstreamLogger,
|
upstreamLogger: upstreamLogger,
|
||||||
|
|
||||||
metricsMonitor: NewMetricsMonitor(&config),
|
metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics),
|
||||||
|
|
||||||
processGroups: make(map[string]*ProcessGroup),
|
processGroups: make(map[string]*ProcessGroup),
|
||||||
|
|
||||||
@@ -193,27 +200,25 @@ func (pm *ProxyManager) setupGinEngine() {
|
|||||||
c.Next()
|
c.Next()
|
||||||
})
|
})
|
||||||
|
|
||||||
mm := MetricsMiddleware(pm)
|
|
||||||
|
|
||||||
// Set up routes using the Gin engine
|
// Set up routes using the Gin engine
|
||||||
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||||
// Support legacy /v1/completions api, see issue #12
|
// Support legacy /v1/completions api, see issue #12
|
||||||
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// Support embeddings and reranking
|
// Support embeddings and reranking
|
||||||
pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// llama-server's /reranking endpoint + aliases
|
// llama-server's /reranking endpoint + aliases
|
||||||
pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/reranking", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/rerank", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler)
|
||||||
pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// llama-server's /infill endpoint for code infilling
|
// llama-server's /infill endpoint for code infilling
|
||||||
pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/infill", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// llama-server's /completion endpoint
|
// llama-server's /completion endpoint
|
||||||
pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler)
|
pm.ginEngine.POST("/completion", pm.proxyOAIHandler)
|
||||||
|
|
||||||
// Support audio/speech endpoint
|
// Support audio/speech endpoint
|
||||||
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler)
|
||||||
@@ -474,8 +479,23 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// rewrite the path
|
// rewrite the path
|
||||||
|
originalPath := c.Request.URL.Path
|
||||||
c.Request.URL.Path = remainingPath
|
c.Request.URL.Path = remainingPath
|
||||||
processGroup.ProxyRequest(realModelName, c.Writer, c.Request)
|
|
||||||
|
// attempt to record metrics if it is a POST request
|
||||||
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
|
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||||
@@ -535,10 +555,18 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
c.Request.ContentLength = int64(len(bodyBytes))
|
c.Request.ContentLength = int64(len(bodyBytes))
|
||||||
|
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
if pm.metricsMonitor != nil && c.Request.Method == "POST" {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil {
|
||||||
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error()))
|
||||||
return
|
pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||||
sendModels()
|
sendModels()
|
||||||
sendMetrics(pm.metricsMonitor.GetMetrics())
|
sendMetrics(pm.metricsMonitor.getMetrics())
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -198,7 +198,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||||
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
jsonData, err := pm.metricsMonitor.getMetricsJSON()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -911,76 +911,6 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
|||||||
// t.Logf("%v", response)
|
// t.Logf("%v", response)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]config.ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
})
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
||||||
|
|
||||||
// Make a non-streaming request
|
|
||||||
reqBody := `{"model":"model1", "stream": false}`
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
|
||||||
w := CreateTestResponseRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
// Check that metrics were recorded
|
|
||||||
metrics := proxy.metricsMonitor.GetMetrics()
|
|
||||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the last metric has the correct model
|
|
||||||
lastMetric := metrics[len(metrics)-1]
|
|
||||||
assert.Equal(t, "model1", lastMetric.Model)
|
|
||||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
|
||||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
|
||||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
|
||||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
|
||||||
HealthCheckTimeout: 15,
|
|
||||||
Models: map[string]config.ModelConfig{
|
|
||||||
"model1": getTestSimpleResponderConfig("model1"),
|
|
||||||
},
|
|
||||||
LogLevel: "error",
|
|
||||||
})
|
|
||||||
|
|
||||||
proxy := New(config)
|
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
|
||||||
|
|
||||||
// Make a streaming request
|
|
||||||
reqBody := `{"model":"model1", "stream": true}`
|
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
|
||||||
w := CreateTestResponseRecorder()
|
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
// Check that metrics were recorded
|
|
||||||
metrics := proxy.metricsMonitor.GetMetrics()
|
|
||||||
if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the last metric has the correct model
|
|
||||||
lastMetric := metrics[len(metrics)-1]
|
|
||||||
assert.Equal(t, "model1", lastMetric.Model)
|
|
||||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
|
||||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
|
||||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
|
||||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
func TestProxyManager_HealthEndpoint(t *testing.T) {
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
|
|||||||
Reference in New Issue
Block a user