diff --git a/proxy/metrics_middleware.go b/proxy/metrics_middleware.go deleted file mode 100644 index 734d75a..0000000 --- a/proxy/metrics_middleware.go +++ /dev/null @@ -1,184 +0,0 @@ -package proxy - -import ( - "bytes" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/tidwall/gjson" -) - -type MetricsRecorder struct { - metricsMonitor *MetricsMonitor - realModelName string - // isStreaming bool - startTime time.Time -} - -// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests -func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc { - return func(c *gin.Context) { - bodyBytes, err := io.ReadAll(c.Request.Body) - if err != nil { - pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") - c.Abort() - return - } - c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) - - requestedModel := gjson.GetBytes(bodyBytes, "model").String() - if requestedModel == "" { - pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") - c.Abort() - return - } - - realModelName, found := pm.config.RealModelName(requestedModel) - if !found { - pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel)) - c.Abort() - return - } - - writer := &MetricsResponseWriter{ - ResponseWriter: c.Writer, - metricsRecorder: &MetricsRecorder{ - metricsMonitor: pm.metricsMonitor, - realModelName: realModelName, - startTime: time.Now(), - }, - } - c.Writer = writer - c.Next() - - // check for streaming response - if strings.Contains(c.Writer.Header().Get("Content-Type"), "text/event-stream") { - writer.metricsRecorder.processStreamingResponse(writer.body) - } else { - writer.metricsRecorder.processNonStreamingResponse(writer.body) - } - } -} - -func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) bool { - usage := jsonData.Get("usage") - timings := jsonData.Get("timings") - if !usage.Exists() && !timings.Exists() { - return false - } - - // default values - cachedTokens := -1 // unknown or missing data - outputTokens := 0 - inputTokens := 0 - - // timings data - tokensPerSecond := -1.0 - promptPerSecond := -1.0 - durationMs := int(time.Since(rec.startTime).Milliseconds()) - - if usage.Exists() { - outputTokens = int(jsonData.Get("usage.completion_tokens").Int()) - inputTokens = int(jsonData.Get("usage.prompt_tokens").Int()) - } - - // use llama-server's timing data for tok/sec and duration as it is more accurate - if timings.Exists() { - inputTokens = int(jsonData.Get("timings.prompt_n").Int()) - outputTokens = int(jsonData.Get("timings.predicted_n").Int()) - promptPerSecond = jsonData.Get("timings.prompt_per_second").Float() - tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float() - durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float()) - - if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() { - cachedTokens = int(cachedValue.Int()) - } - } - - rec.metricsMonitor.addMetrics(TokenMetrics{ - Timestamp: time.Now(), - Model: rec.realModelName, - CachedTokens: cachedTokens, - InputTokens: inputTokens, - OutputTokens: outputTokens, - PromptPerSecond: promptPerSecond, - TokensPerSecond: tokensPerSecond, - DurationMs: durationMs, - }) - - return true -} - -func (rec *MetricsRecorder) processStreamingResponse(body []byte) { - // Iterate **backwards** through the lines looking for the data payload with - // usage data - lines := bytes.Split(body, []byte("\n")) - - for i := len(lines) - 1; i >= 0; i-- { - line := bytes.TrimSpace(lines[i]) - if len(line) == 0 { - continue - } - - // SSE payload always follows "data:" - prefix := []byte("data:") - if !bytes.HasPrefix(line, prefix) { - continue - } - data := bytes.TrimSpace(line[len(prefix):]) - - if len(data) == 0 { - continue - } - - if bytes.Equal(data, []byte("[DONE]")) { - // [DONE] line itself contains nothing of interest. - continue - } - - if gjson.ValidBytes(data) { - if rec.parseAndRecordMetrics(gjson.ParseBytes(data)) { - return // short circuit if a metric was recorded - } - } - } -} - -func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) { - if len(body) == 0 { - return - } - - // Parse JSON to extract usage information - if gjson.ValidBytes(body) { - rec.parseAndRecordMetrics(gjson.ParseBytes(body)) - } -} - -// MetricsResponseWriter captures the entire response for non-streaming -type MetricsResponseWriter struct { - gin.ResponseWriter - body []byte - metricsRecorder *MetricsRecorder -} - -func (w *MetricsResponseWriter) Write(b []byte) (int, error) { - n, err := w.ResponseWriter.Write(b) - if err != nil { - return n, err - } - w.body = append(w.body, b...) - return n, nil -} - -func (w *MetricsResponseWriter) WriteHeader(statusCode int) { - w.ResponseWriter.WriteHeader(statusCode) -} - -func (w *MetricsResponseWriter) Header() http.Header { - return w.ResponseWriter.Header() -} diff --git a/proxy/metrics_monitor.go b/proxy/metrics_monitor.go index 826870f..360eb22 100644 --- a/proxy/metrics_monitor.go +++ b/proxy/metrics_monitor.go @@ -1,12 +1,18 @@ package proxy import ( + "bytes" "encoding/json" + "fmt" + "io" + "net/http" + "strings" "sync" "time" + "github.com/gin-gonic/gin" "github.com/mostlygeek/llama-swap/event" - "github.com/mostlygeek/llama-swap/proxy/config" + "github.com/tidwall/gjson" ) // TokenMetrics represents parsed token statistics from llama-server logs @@ -31,21 +37,18 @@ func (e TokenMetricsEvent) Type() uint32 { return TokenMetricsEventID // defined in events.go } -// MetricsMonitor parses llama-server output for token statistics -type MetricsMonitor struct { +// metricsMonitor parses llama-server output for token statistics +type metricsMonitor struct { mu sync.RWMutex metrics []TokenMetrics maxMetrics int nextID int + logger *LogMonitor } -func NewMetricsMonitor(config *config.Config) *MetricsMonitor { - maxMetrics := config.MetricsMaxInMemory - if maxMetrics <= 0 { - maxMetrics = 1000 // Default fallback - } - - mp := &MetricsMonitor{ +func newMetricsMonitor(logger *LogMonitor, maxMetrics int) *metricsMonitor { + mp := &metricsMonitor{ + logger: logger, maxMetrics: maxMetrics, } @@ -53,7 +56,7 @@ func NewMetricsMonitor(config *config.Config) *MetricsMonitor { } // addMetrics adds a new metric to the collection and publishes an event -func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) { +func (mp *metricsMonitor) addMetrics(metric TokenMetrics) { mp.mu.Lock() defer mp.mu.Unlock() @@ -66,8 +69,8 @@ func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) { event.Emit(TokenMetricsEvent{Metrics: metric}) } -// GetMetrics returns a copy of the current metrics -func (mp *MetricsMonitor) GetMetrics() []TokenMetrics { +// getMetrics returns a copy of the current metrics +func (mp *metricsMonitor) getMetrics() []TokenMetrics { mp.mu.RLock() defer mp.mu.RUnlock() @@ -76,9 +79,189 @@ func (mp *MetricsMonitor) GetMetrics() []TokenMetrics { return result } -// GetMetricsJSON returns metrics as JSON -func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) { +// getMetricsJSON returns metrics as JSON +func (mp *metricsMonitor) getMetricsJSON() ([]byte, error) { mp.mu.RLock() defer mp.mu.RUnlock() return json.Marshal(mp.metrics) } + +// wrapHandler wraps the proxy handler to extract token metrics +// if wrapHandler returns an error it is safe to assume that no +// data was sent to the client +func (mp *metricsMonitor) wrapHandler( + modelID string, + writer gin.ResponseWriter, + request *http.Request, + next func(modelID string, w http.ResponseWriter, r *http.Request) error, +) error { + recorder := newBodyCopier(writer) + if err := next(modelID, recorder, request); err != nil { + return err + } + + // after this point we have to assume that data was sent to the client + // and we can only log errors but not send them to clients + + if recorder.Status() != http.StatusOK { + mp.logger.Warnf("metrics skipped, HTTP status=%d, path=%s", recorder.Status(), request.URL.Path) + return nil + } + + body := recorder.body.Bytes() + if len(body) == 0 { + mp.logger.Warn("metrics skipped, empty body") + return nil + } + + if strings.Contains(recorder.Header().Get("Content-Type"), "text/event-stream") { + if tm, err := processStreamingResponse(modelID, recorder.StartTime(), body); err != nil { + mp.logger.Warnf("error processing streaming response: %v, path=%s", err, request.URL.Path) + } else { + mp.addMetrics(tm) + } + } else { + if gjson.ValidBytes(body) { + if tm, err := parseMetrics(modelID, recorder.StartTime(), gjson.ParseBytes(body)); err != nil { + mp.logger.Warnf("error parsing metrics: %v, path=%s", err, request.URL.Path) + } else { + mp.addMetrics(tm) + } + } else { + mp.logger.Warnf("metrics skipped, invalid JSON in response body path=%s", request.URL.Path) + } + } + + return nil +} + +func processStreamingResponse(modelID string, start time.Time, body []byte) (TokenMetrics, error) { + // Iterate **backwards** through the body looking for the data payload with + // usage data. This avoids allocating a slice of all lines via bytes.Split. + + // Start from the end of the body and scan backwards for newlines + pos := len(body) + for pos > 0 { + // Find the previous newline (or start of body) + lineStart := bytes.LastIndexByte(body[:pos], '\n') + if lineStart == -1 { + lineStart = 0 + } else { + lineStart++ // Move past the newline + } + + line := bytes.TrimSpace(body[lineStart:pos]) + pos = lineStart - 1 // Move position before the newline for next iteration + + if len(line) == 0 { + continue + } + + // SSE payload always follows "data:" + prefix := []byte("data:") + if !bytes.HasPrefix(line, prefix) { + continue + } + data := bytes.TrimSpace(line[len(prefix):]) + + if len(data) == 0 { + continue + } + + if bytes.Equal(data, []byte("[DONE]")) { + // [DONE] line itself contains nothing of interest. + continue + } + + if gjson.ValidBytes(data) { + return parseMetrics(modelID, start, gjson.ParseBytes(data)) + } + } + + return TokenMetrics{}, fmt.Errorf("no valid JSON data found in stream") +} + +func parseMetrics(modelID string, start time.Time, jsonData gjson.Result) (TokenMetrics, error) { + usage := jsonData.Get("usage") + timings := jsonData.Get("timings") + if !usage.Exists() && !timings.Exists() { + return TokenMetrics{}, fmt.Errorf("no usage or timings data found") + } + // default values + cachedTokens := -1 // unknown or missing data + outputTokens := 0 + inputTokens := 0 + + // timings data + tokensPerSecond := -1.0 + promptPerSecond := -1.0 + durationMs := int(time.Since(start).Milliseconds()) + + if usage.Exists() { + outputTokens = int(jsonData.Get("usage.completion_tokens").Int()) + inputTokens = int(jsonData.Get("usage.prompt_tokens").Int()) + } + + // use llama-server's timing data for tok/sec and duration as it is more accurate + if timings.Exists() { + inputTokens = int(jsonData.Get("timings.prompt_n").Int()) + outputTokens = int(jsonData.Get("timings.predicted_n").Int()) + promptPerSecond = jsonData.Get("timings.prompt_per_second").Float() + tokensPerSecond = jsonData.Get("timings.predicted_per_second").Float() + durationMs = int(jsonData.Get("timings.prompt_ms").Float() + jsonData.Get("timings.predicted_ms").Float()) + + if cachedValue := jsonData.Get("timings.cache_n"); cachedValue.Exists() { + cachedTokens = int(cachedValue.Int()) + } + } + + return TokenMetrics{ + Timestamp: time.Now(), + Model: modelID, + CachedTokens: cachedTokens, + InputTokens: inputTokens, + OutputTokens: outputTokens, + PromptPerSecond: promptPerSecond, + TokensPerSecond: tokensPerSecond, + DurationMs: durationMs, + }, nil +} + +// responseBodyCopier records the response body and writes to the original response writer +// while also capturing it in a buffer for later processing +type responseBodyCopier struct { + gin.ResponseWriter + body *bytes.Buffer + tee io.Writer + start time.Time +} + +func newBodyCopier(w gin.ResponseWriter) *responseBodyCopier { + bodyBuffer := &bytes.Buffer{} + return &responseBodyCopier{ + ResponseWriter: w, + body: bodyBuffer, + tee: io.MultiWriter(w, bodyBuffer), + } +} + +func (w *responseBodyCopier) Write(b []byte) (int, error) { + if w.start.IsZero() { + w.start = time.Now() + } + + // Single write operation that writes to both the response and buffer + return w.tee.Write(b) +} + +func (w *responseBodyCopier) WriteHeader(statusCode int) { + w.ResponseWriter.WriteHeader(statusCode) +} + +func (w *responseBodyCopier) Header() http.Header { + return w.ResponseWriter.Header() +} + +func (w *responseBodyCopier) StartTime() time.Time { + return w.start +} diff --git a/proxy/metrics_monitor_test.go b/proxy/metrics_monitor_test.go new file mode 100644 index 0000000..fb35388 --- /dev/null +++ b/proxy/metrics_monitor_test.go @@ -0,0 +1,693 @@ +package proxy + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/gin-gonic/gin" + "github.com/mostlygeek/llama-swap/event" + "github.com/stretchr/testify/assert" +) + +func TestMetricsMonitor_AddMetrics(t *testing.T) { + t.Run("adds metrics and assigns ID", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + metric := TokenMetrics{ + Model: "test-model", + InputTokens: 100, + OutputTokens: 50, + } + + mm.addMetrics(metric) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, 0, metrics[0].ID) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 100, metrics[0].InputTokens) + assert.Equal(t, 50, metrics[0].OutputTokens) + }) + + t.Run("increments ID for each metric", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + for i := 0; i < 5; i++ { + mm.addMetrics(TokenMetrics{Model: "model"}) + } + + metrics := mm.getMetrics() + assert.Equal(t, 5, len(metrics)) + for i := 0; i < 5; i++ { + assert.Equal(t, i, metrics[i].ID) + } + }) + + t.Run("respects max metrics limit", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 3) + + // Add 5 metrics + for i := 0; i < 5; i++ { + mm.addMetrics(TokenMetrics{ + Model: "model", + InputTokens: i, + }) + } + + metrics := mm.getMetrics() + assert.Equal(t, 3, len(metrics)) + + // Should keep the last 3 metrics (IDs 2, 3, 4) + assert.Equal(t, 2, metrics[0].ID) + assert.Equal(t, 3, metrics[1].ID) + assert.Equal(t, 4, metrics[2].ID) + }) + + t.Run("emits TokenMetricsEvent", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + receivedEvent := make(chan TokenMetricsEvent, 1) + cancel := event.On(func(e TokenMetricsEvent) { + receivedEvent <- e + }) + defer cancel() + + metric := TokenMetrics{ + Model: "test-model", + InputTokens: 100, + OutputTokens: 50, + } + + mm.addMetrics(metric) + + select { + case evt := <-receivedEvent: + assert.Equal(t, 0, evt.Metrics.ID) + assert.Equal(t, "test-model", evt.Metrics.Model) + assert.Equal(t, 100, evt.Metrics.InputTokens) + assert.Equal(t, 50, evt.Metrics.OutputTokens) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for event") + } + }) +} + +func TestMetricsMonitor_GetMetrics(t *testing.T) { + t.Run("returns empty slice when no metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + metrics := mm.getMetrics() + assert.NotNil(t, metrics) + assert.Equal(t, 0, len(metrics)) + }) + + t.Run("returns copy of metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + mm.addMetrics(TokenMetrics{Model: "model1"}) + mm.addMetrics(TokenMetrics{Model: "model2"}) + + metrics1 := mm.getMetrics() + metrics2 := mm.getMetrics() + + // Verify we got copies + assert.Equal(t, 2, len(metrics1)) + assert.Equal(t, 2, len(metrics2)) + + // Modify the returned slice shouldn't affect the original + metrics1[0].Model = "modified" + metrics3 := mm.getMetrics() + assert.Equal(t, "model1", metrics3[0].Model) + }) +} + +func TestMetricsMonitor_GetMetricsJSON(t *testing.T) { + t.Run("returns valid JSON for empty metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + jsonData, err := mm.getMetricsJSON() + assert.NoError(t, err) + assert.NotNil(t, jsonData) + + var metrics []TokenMetrics + err = json.Unmarshal(jsonData, &metrics) + assert.NoError(t, err) + assert.Equal(t, 0, len(metrics)) + }) + + t.Run("returns valid JSON with metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + mm.addMetrics(TokenMetrics{ + Model: "model1", + InputTokens: 100, + OutputTokens: 50, + TokensPerSecond: 25.5, + }) + mm.addMetrics(TokenMetrics{ + Model: "model2", + InputTokens: 200, + OutputTokens: 100, + TokensPerSecond: 30.0, + }) + + jsonData, err := mm.getMetricsJSON() + assert.NoError(t, err) + + var metrics []TokenMetrics + err = json.Unmarshal(jsonData, &metrics) + assert.NoError(t, err) + assert.Equal(t, 2, len(metrics)) + assert.Equal(t, "model1", metrics[0].Model) + assert.Equal(t, "model2", metrics[1].Model) + }) +} + +func TestMetricsMonitor_WrapHandler(t *testing.T) { + t.Run("successful non-streaming request with usage data", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `{ + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50 + } + }` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 100, metrics[0].InputTokens) + assert.Equal(t, 50, metrics[0].OutputTokens) + }) + + t.Run("successful request with timings data", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `{ + "timings": { + "prompt_n": 100, + "predicted_n": 50, + "prompt_per_second": 150.5, + "predicted_per_second": 25.5, + "prompt_ms": 500.0, + "predicted_ms": 1500.0, + "cache_n": 20 + } + }` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + assert.Equal(t, 100, metrics[0].InputTokens) + assert.Equal(t, 50, metrics[0].OutputTokens) + assert.Equal(t, 20, metrics[0].CachedTokens) + assert.Equal(t, 150.5, metrics[0].PromptPerSecond) + assert.Equal(t, 25.5, metrics[0].TokensPerSecond) + assert.Equal(t, 2000, metrics[0].DurationMs) // 500 + 1500 + }) + + t.Run("streaming request with SSE format", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + // Note: SSE format requires proper line breaks - each data line followed by blank line + responseBody := `data: {"choices":[{"text":"Hello"}]} + +data: {"choices":[{"text":" World"}]} + +data: {"usage":{"prompt_tokens":10,"completion_tokens":20},"timings":{"prompt_n":10,"predicted_n":20,"prompt_per_second":100.0,"predicted_per_second":50.0,"prompt_ms":100.0,"predicted_ms":400.0}} + +data: [DONE] + +` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, "test-model", metrics[0].Model) + // When timings data is present, it takes precedence + assert.Equal(t, 10, metrics[0].InputTokens) + assert.Equal(t, 20, metrics[0].OutputTokens) + }) + + t.Run("non-OK status code does not record metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte("error")) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 0, len(metrics)) + }) + + t.Run("empty response body does not record metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.WriteHeader(http.StatusOK) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 0, len(metrics)) + }) + + t.Run("invalid JSON does not record metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte("not valid json")) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) // Errors after response is sent are logged, not returned + + metrics := mm.getMetrics() + assert.Equal(t, 0, len(metrics)) + }) + + t.Run("next handler error is propagated", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + expectedErr := assert.AnError + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + return expectedErr + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.Equal(t, expectedErr, err) + + metrics := mm.getMetrics() + assert.Equal(t, 0, len(metrics)) + }) + + t.Run("response without usage or timings does not record metrics", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `{"result": "ok"}` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) // Errors after response is sent are logged, not returned + + metrics := mm.getMetrics() + assert.Equal(t, 0, len(metrics)) + }) +} + +func TestMetricsMonitor_ResponseBodyCopier(t *testing.T) { + t.Run("captures response body", func(t *testing.T) { + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + copier := newBodyCopier(ginCtx.Writer) + + testData := []byte("test response body") + n, err := copier.Write(testData) + + assert.NoError(t, err) + assert.Equal(t, len(testData), n) + assert.Equal(t, testData, copier.body.Bytes()) + assert.Equal(t, string(testData), rec.Body.String()) + }) + + t.Run("sets start time on first write", func(t *testing.T) { + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + copier := newBodyCopier(ginCtx.Writer) + + assert.True(t, copier.StartTime().IsZero()) + + copier.Write([]byte("test")) + + assert.False(t, copier.StartTime().IsZero()) + }) + + t.Run("preserves headers", func(t *testing.T) { + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + copier := newBodyCopier(ginCtx.Writer) + + copier.Header().Set("X-Test", "value") + + assert.Equal(t, "value", rec.Header().Get("X-Test")) + }) + + t.Run("preserves status code", func(t *testing.T) { + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + copier := newBodyCopier(ginCtx.Writer) + + copier.WriteHeader(http.StatusCreated) + + // Gin's ResponseWriter tracks status internally + assert.Equal(t, http.StatusCreated, copier.Status()) + }) +} + +func TestMetricsMonitor_Concurrent(t *testing.T) { + t.Run("concurrent addMetrics is safe", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 1000) + + var wg sync.WaitGroup + numGoroutines := 10 + metricsPerGoroutine := 100 + + for i := 0; i < numGoroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < metricsPerGoroutine; j++ { + mm.addMetrics(TokenMetrics{ + Model: "test-model", + InputTokens: id*1000 + j, + OutputTokens: j, + }) + } + }(i) + } + + wg.Wait() + + metrics := mm.getMetrics() + assert.Equal(t, numGoroutines*metricsPerGoroutine, len(metrics)) + }) + + t.Run("concurrent reads and writes are safe", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 100) + + done := make(chan bool) + + // Writer goroutine + go func() { + for i := 0; i < 50; i++ { + mm.addMetrics(TokenMetrics{Model: "test-model"}) + time.Sleep(1 * time.Millisecond) + } + done <- true + }() + + // Multiple reader goroutines + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 20; j++ { + _ = mm.getMetrics() + _, _ = mm.getMetricsJSON() + time.Sleep(2 * time.Millisecond) + } + }() + } + + <-done + wg.Wait() + + // Final check + metrics := mm.getMetrics() + assert.Equal(t, 50, len(metrics)) + }) +} + +func TestMetricsMonitor_ParseMetrics(t *testing.T) { + t.Run("prefers timings over usage data", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + // Timings should take precedence over usage + responseBody := `{ + "usage": { + "prompt_tokens": 50, + "completion_tokens": 25 + }, + "timings": { + "prompt_n": 100, + "predicted_n": 50, + "prompt_per_second": 150.5, + "predicted_per_second": 25.5, + "prompt_ms": 500.0, + "predicted_ms": 1500.0 + } + }` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + // Should use timings values, not usage values + assert.Equal(t, 100, metrics[0].InputTokens) + assert.Equal(t, 50, metrics[0].OutputTokens) + }) + + t.Run("handles missing cache_n in timings", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `{ + "timings": { + "prompt_n": 100, + "predicted_n": 50, + "prompt_per_second": 150.5, + "predicted_per_second": 25.5, + "prompt_ms": 500.0, + "predicted_ms": 1500.0 + } + }` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, -1, metrics[0].CachedTokens) // Default value when not present + }) +} + +func TestMetricsMonitor_StreamingResponse(t *testing.T) { + t.Run("finds metrics in last valid SSE data", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + // Metrics should be found in the last data line before [DONE] + responseBody := `data: {"choices":[{"text":"First"}]} + +data: {"choices":[{"text":"Second"}]} + +data: {"usage":{"prompt_tokens":100,"completion_tokens":50}} + +data: [DONE] + +` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 1, len(metrics)) + assert.Equal(t, 100, metrics[0].InputTokens) + assert.Equal(t, 50, metrics[0].OutputTokens) + }) + + t.Run("handles streaming with no valid JSON", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `data: not json + +data: [DONE] + +` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + assert.NoError(t, err) // Errors after response is sent are logged, not returned + + metrics := mm.getMetrics() + assert.Equal(t, 0, len(metrics)) + }) + + t.Run("handles empty streaming response", func(t *testing.T) { + mm := newMetricsMonitor(testLogger, 10) + + responseBody := `` + + nextHandler := func(modelID string, w http.ResponseWriter, r *http.Request) error { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + w.Write([]byte(responseBody)) + return nil + } + + req := httptest.NewRequest("POST", "/test", nil) + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + + err := mm.wrapHandler("test-model", ginCtx.Writer, req, nextHandler) + // Empty body should not trigger WrapHandler processing + assert.NoError(t, err) + + metrics := mm.getMetrics() + assert.Equal(t, 0, len(metrics)) + }) +} + +// Benchmark tests +func BenchmarkMetricsMonitor_AddMetrics(b *testing.B) { + mm := newMetricsMonitor(testLogger, 1000) + + metric := TokenMetrics{ + Model: "test-model", + CachedTokens: 100, + InputTokens: 500, + OutputTokens: 250, + PromptPerSecond: 1200.5, + TokensPerSecond: 45.8, + DurationMs: 5000, + Timestamp: time.Now(), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mm.addMetrics(metric) + } +} + +func BenchmarkMetricsMonitor_AddMetrics_SmallBuffer(b *testing.B) { + // Test performance with a smaller buffer where wrapping occurs more frequently + mm := newMetricsMonitor(testLogger, 100) + + metric := TokenMetrics{ + Model: "test-model", + CachedTokens: 100, + InputTokens: 500, + OutputTokens: 250, + PromptPerSecond: 1200.5, + TokensPerSecond: 45.8, + DurationMs: 5000, + Timestamp: time.Now(), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mm.addMetrics(metric) + } +} diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 9888410..683d64b 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -36,7 +36,7 @@ type ProxyManager struct { upstreamLogger *LogMonitor muxLogger *LogMonitor - metricsMonitor *MetricsMonitor + metricsMonitor *metricsMonitor processGroups map[string]*ProcessGroup @@ -75,6 +75,13 @@ func New(config config.Config) *ProxyManager { shutdownCtx, shutdownCancel := context.WithCancel(context.Background()) + var maxMetrics int + if config.MetricsMaxInMemory <= 0 { + maxMetrics = 1000 // Default fallback + } else { + maxMetrics = config.MetricsMaxInMemory + } + pm := &ProxyManager{ config: config, ginEngine: gin.New(), @@ -83,7 +90,7 @@ func New(config config.Config) *ProxyManager { muxLogger: stdoutLogger, upstreamLogger: upstreamLogger, - metricsMonitor: NewMetricsMonitor(&config), + metricsMonitor: newMetricsMonitor(proxyLogger, maxMetrics), processGroups: make(map[string]*ProcessGroup), @@ -193,27 +200,25 @@ func (pm *ProxyManager) setupGinEngine() { c.Next() }) - mm := MetricsMiddleware(pm) - // Set up routes using the Gin engine - pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler) + pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler) // Support legacy /v1/completions api, see issue #12 - pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler) + pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler) // Support embeddings and reranking - pm.ginEngine.POST("/v1/embeddings", mm, pm.proxyOAIHandler) + pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler) // llama-server's /reranking endpoint + aliases - pm.ginEngine.POST("/reranking", mm, pm.proxyOAIHandler) - pm.ginEngine.POST("/rerank", mm, pm.proxyOAIHandler) - pm.ginEngine.POST("/v1/rerank", mm, pm.proxyOAIHandler) - pm.ginEngine.POST("/v1/reranking", mm, pm.proxyOAIHandler) + pm.ginEngine.POST("/reranking", pm.proxyOAIHandler) + pm.ginEngine.POST("/rerank", pm.proxyOAIHandler) + pm.ginEngine.POST("/v1/rerank", pm.proxyOAIHandler) + pm.ginEngine.POST("/v1/reranking", pm.proxyOAIHandler) // llama-server's /infill endpoint for code infilling - pm.ginEngine.POST("/infill", mm, pm.proxyOAIHandler) + pm.ginEngine.POST("/infill", pm.proxyOAIHandler) // llama-server's /completion endpoint - pm.ginEngine.POST("/completion", mm, pm.proxyOAIHandler) + pm.ginEngine.POST("/completion", pm.proxyOAIHandler) // Support audio/speech endpoint pm.ginEngine.POST("/v1/audio/speech", pm.proxyOAIHandler) @@ -474,8 +479,23 @@ func (pm *ProxyManager) proxyToUpstream(c *gin.Context) { } // rewrite the path + originalPath := c.Request.URL.Path c.Request.URL.Path = remainingPath - processGroup.ProxyRequest(realModelName, c.Writer, c.Request) + + // attempt to record metrics if it is a POST request + if pm.metricsMonitor != nil && c.Request.Method == "POST" { + if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) + pm.proxyLogger.Errorf("Error proxying wrapped upstream request for model %s, path=%s", realModelName, originalPath) + return + } + } else { + if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) + pm.proxyLogger.Errorf("Error proxying upstream request for model %s, path=%s", realModelName, originalPath) + return + } + } } func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { @@ -535,10 +555,18 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { c.Request.Header.Set("content-length", strconv.Itoa(len(bodyBytes))) c.Request.ContentLength = int64(len(bodyBytes)) - if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { - pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) - pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) - return + if pm.metricsMonitor != nil && c.Request.Method == "POST" { + if err := pm.metricsMonitor.wrapHandler(realModelName, c.Writer, c.Request, processGroup.ProxyRequest); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying metrics wrapped request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Metrics Wrapped Request for processGroup %s and model %s", processGroup.id, realModelName) + return + } + } else { + if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { + pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) + pm.proxyLogger.Errorf("Error Proxying Request for processGroup %s and model %s", processGroup.id, realModelName) + return + } } } diff --git a/proxy/proxymanager_api.go b/proxy/proxymanager_api.go index 7100e2c..76f6cd3 100644 --- a/proxy/proxymanager_api.go +++ b/proxy/proxymanager_api.go @@ -180,7 +180,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) { sendLogData("proxy", pm.proxyLogger.GetHistory()) sendLogData("upstream", pm.upstreamLogger.GetHistory()) sendModels() - sendMetrics(pm.metricsMonitor.GetMetrics()) + sendMetrics(pm.metricsMonitor.getMetrics()) for { select { @@ -198,7 +198,7 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) { } func (pm *ProxyManager) apiGetMetrics(c *gin.Context) { - jsonData, err := pm.metricsMonitor.GetMetricsJSON() + jsonData, err := pm.metricsMonitor.getMetricsJSON() if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"}) return diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index c4eb6cc..1bc5a12 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -911,76 +911,6 @@ func TestProxyManager_FiltersStripParams(t *testing.T) { // t.Logf("%v", response) } -func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) - - proxy := New(config) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - // Make a non-streaming request - reqBody := `{"model":"model1", "stream": false}` - req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - - // Check that metrics were recorded - metrics := proxy.metricsMonitor.GetMetrics() - if !assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request") { - return - } - - // Verify the last metric has the correct model - lastMetric := metrics[len(metrics)-1] - assert.Equal(t, "model1", lastMetric.Model) - assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25") - assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10") - assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0") - assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0") -} - -func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) { - config := config.AddDefaultGroupToConfig(config.Config{ - HealthCheckTimeout: 15, - Models: map[string]config.ModelConfig{ - "model1": getTestSimpleResponderConfig("model1"), - }, - LogLevel: "error", - }) - - proxy := New(config) - defer proxy.StopProcesses(StopWaitForInflightRequest) - - // Make a streaming request - reqBody := `{"model":"model1", "stream": true}` - req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody)) - w := CreateTestResponseRecorder() - - proxy.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - - // Check that metrics were recorded - metrics := proxy.metricsMonitor.GetMetrics() - if !assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request") { - return - } - - // Verify the last metric has the correct model - lastMetric := metrics[len(metrics)-1] - assert.Equal(t, "model1", lastMetric.Model) - assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25") - assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10") - assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0") - assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0") -} - func TestProxyManager_HealthEndpoint(t *testing.T) { config := config.AddDefaultGroupToConfig(config.Config{ HealthCheckTimeout: 15,