Add metrics logging for chat completion requests (#195)
- Add token and performance metrics for v1/chat/completions - Add Activity Page in UI - Add /api/metrics endpoint Contributed by @g2mt
This commit is contained in:
@@ -142,6 +142,7 @@ type Config struct {
|
||||
HealthCheckTimeout int `yaml:"healthCheckTimeout"`
|
||||
LogRequests bool `yaml:"logRequests"`
|
||||
LogLevel string `yaml:"logLevel"`
|
||||
MetricsMaxInMemory int `yaml:"metricsMaxInMemory"`
|
||||
Models map[string]ModelConfig `yaml:"models"` /* key is model ID */
|
||||
Profiles map[string][]string `yaml:"profiles"`
|
||||
Groups map[string]GroupConfig `yaml:"groups"` /* key is group ID */
|
||||
@@ -194,6 +195,7 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
||||
HealthCheckTimeout: 120,
|
||||
StartPort: 5800,
|
||||
LogLevel: "info",
|
||||
MetricsMaxInMemory: 1000,
|
||||
}
|
||||
err = yaml.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
|
||||
@@ -196,6 +196,7 @@ groups:
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -193,6 +193,7 @@ groups:
|
||||
},
|
||||
},
|
||||
HealthCheckTimeout: 15,
|
||||
MetricsMaxInMemory: 1000,
|
||||
Profiles: map[string][]string{
|
||||
"test": {"model1", "model2"},
|
||||
},
|
||||
|
||||
@@ -6,6 +6,7 @@ const ProcessStateChangeEventID = 0x01
|
||||
const ChatCompletionStatsEventID = 0x02
|
||||
const ConfigFileChangedEventID = 0x03
|
||||
const LogDataEventID = 0x04
|
||||
const TokenMetricsEventID = 0x05
|
||||
|
||||
type ProcessStateChangeEvent struct {
|
||||
ProcessName string
|
||||
|
||||
153
proxy/metrics_middleware.go
Normal file
153
proxy/metrics_middleware.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// MetricsMiddleware sets up the MetricsResponseWriter for capturing upstream requests
|
||||
func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
|
||||
return
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
realModelName, found := pm.config.RealModelName(requestedModel)
|
||||
if !found {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
|
||||
return
|
||||
}
|
||||
c.Set("ls-real-model-name", realModelName)
|
||||
|
||||
writer := &MetricsResponseWriter{
|
||||
ResponseWriter: c.Writer,
|
||||
metricsRecorder: &MetricsRecorder{
|
||||
metricsMonitor: pm.metricsMonitor,
|
||||
realModelName: realModelName,
|
||||
isStreaming: gjson.GetBytes(bodyBytes, "stream").Bool(),
|
||||
startTime: time.Now(),
|
||||
},
|
||||
}
|
||||
c.Writer = writer
|
||||
c.Next()
|
||||
|
||||
rec := writer.metricsRecorder
|
||||
rec.processBody(writer.body)
|
||||
}
|
||||
}
|
||||
|
||||
type MetricsRecorder struct {
|
||||
metricsMonitor *MetricsMonitor
|
||||
realModelName string
|
||||
isStreaming bool
|
||||
startTime time.Time
|
||||
}
|
||||
|
||||
// processBody handles response processing after request completes
|
||||
func (rec *MetricsRecorder) processBody(body []byte) {
|
||||
if rec.isStreaming {
|
||||
rec.processStreamingResponse(body)
|
||||
} else {
|
||||
rec.processNonStreamingResponse(body)
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) parseAndRecordMetrics(jsonData gjson.Result) {
|
||||
if !jsonData.Get("usage").Exists() {
|
||||
return
|
||||
}
|
||||
|
||||
outputTokens := int(jsonData.Get("usage.completion_tokens").Int())
|
||||
inputTokens := int(jsonData.Get("usage.prompt_tokens").Int())
|
||||
|
||||
if outputTokens > 0 {
|
||||
duration := time.Since(rec.startTime)
|
||||
tokensPerSecond := float64(inputTokens+outputTokens) / duration.Seconds()
|
||||
|
||||
metrics := TokenMetrics{
|
||||
Timestamp: time.Now(),
|
||||
Model: rec.realModelName,
|
||||
InputTokens: inputTokens,
|
||||
OutputTokens: outputTokens,
|
||||
TokensPerSecond: tokensPerSecond,
|
||||
DurationMs: int(duration.Milliseconds()),
|
||||
}
|
||||
rec.metricsMonitor.addMetrics(metrics)
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processStreamingResponse(body []byte) {
|
||||
lines := bytes.Split(body, []byte("\n"))
|
||||
for _, line := range lines {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check for SSE data prefix
|
||||
if data, found := bytes.CutPrefix(line, []byte("data:")); found {
|
||||
data = bytes.TrimSpace(data)
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
break
|
||||
}
|
||||
|
||||
// Parse JSON to look for usage data
|
||||
if gjson.ValidBytes(data) {
|
||||
rec.parseAndRecordMetrics(gjson.ParseBytes(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rec *MetricsRecorder) processNonStreamingResponse(body []byte) {
|
||||
if len(body) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Parse JSON to extract usage information
|
||||
if gjson.ValidBytes(body) {
|
||||
rec.parseAndRecordMetrics(gjson.ParseBytes(body))
|
||||
}
|
||||
}
|
||||
|
||||
// MetricsResponseWriter captures the entire response for non-streaming
|
||||
type MetricsResponseWriter struct {
|
||||
gin.ResponseWriter
|
||||
body []byte
|
||||
metricsRecorder *MetricsRecorder
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(b)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
w.body = append(w.body, b...)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) WriteHeader(statusCode int) {
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func (w *MetricsResponseWriter) Header() http.Header {
|
||||
return w.ResponseWriter.Header()
|
||||
}
|
||||
82
proxy/metrics_monitor.go
Normal file
82
proxy/metrics_monitor.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/mostlygeek/llama-swap/event"
|
||||
)
|
||||
|
||||
// TokenMetrics represents parsed token statistics from llama-server logs
|
||||
type TokenMetrics struct {
|
||||
ID int `json:"id"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Model string `json:"model"`
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
TokensPerSecond float64 `json:"tokens_per_second"`
|
||||
DurationMs int `json:"duration_ms"`
|
||||
}
|
||||
|
||||
// TokenMetricsEvent represents a token metrics event
|
||||
type TokenMetricsEvent struct {
|
||||
Metrics TokenMetrics
|
||||
}
|
||||
|
||||
func (e TokenMetricsEvent) Type() uint32 {
|
||||
return TokenMetricsEventID // defined in events.go
|
||||
}
|
||||
|
||||
// MetricsMonitor parses llama-server output for token statistics
|
||||
type MetricsMonitor struct {
|
||||
mu sync.RWMutex
|
||||
metrics []TokenMetrics
|
||||
maxMetrics int
|
||||
nextID int
|
||||
}
|
||||
|
||||
func NewMetricsMonitor(config *Config) *MetricsMonitor {
|
||||
maxMetrics := config.MetricsMaxInMemory
|
||||
if maxMetrics <= 0 {
|
||||
maxMetrics = 1000 // Default fallback
|
||||
}
|
||||
|
||||
mp := &MetricsMonitor{
|
||||
maxMetrics: maxMetrics,
|
||||
}
|
||||
|
||||
return mp
|
||||
}
|
||||
|
||||
// addMetrics adds a new metric to the collection and publishes an event
|
||||
func (mp *MetricsMonitor) addMetrics(metric TokenMetrics) {
|
||||
mp.mu.Lock()
|
||||
defer mp.mu.Unlock()
|
||||
|
||||
metric.ID = mp.nextID
|
||||
mp.nextID++
|
||||
mp.metrics = append(mp.metrics, metric)
|
||||
if len(mp.metrics) > mp.maxMetrics {
|
||||
mp.metrics = mp.metrics[len(mp.metrics)-mp.maxMetrics:]
|
||||
}
|
||||
|
||||
event.Emit(TokenMetricsEvent{Metrics: metric})
|
||||
}
|
||||
|
||||
// GetMetrics returns a copy of the current metrics
|
||||
func (mp *MetricsMonitor) GetMetrics() []TokenMetrics {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
|
||||
result := make([]TokenMetrics, len(mp.metrics))
|
||||
copy(result, mp.metrics)
|
||||
return result
|
||||
}
|
||||
|
||||
// GetMetricsJSON returns metrics as JSON
|
||||
func (mp *MetricsMonitor) GetMetricsJSON() ([]byte, error) {
|
||||
mp.mu.RLock()
|
||||
defer mp.mu.RUnlock()
|
||||
return json.Marshal(mp.metrics)
|
||||
}
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -33,6 +32,8 @@ type ProxyManager struct {
|
||||
upstreamLogger *LogMonitor
|
||||
muxLogger *LogMonitor
|
||||
|
||||
metricsMonitor *MetricsMonitor
|
||||
|
||||
processGroups map[string]*ProcessGroup
|
||||
|
||||
// shutdown signaling
|
||||
@@ -78,6 +79,8 @@ func New(config Config) *ProxyManager {
|
||||
muxLogger: stdoutLogger,
|
||||
upstreamLogger: upstreamLogger,
|
||||
|
||||
metricsMonitor: NewMetricsMonitor(&config),
|
||||
|
||||
processGroups: make(map[string]*ProcessGroup),
|
||||
|
||||
shutdownCtx: shutdownCtx,
|
||||
@@ -149,10 +152,12 @@ func (pm *ProxyManager) setupGinEngine() {
|
||||
c.Next()
|
||||
})
|
||||
|
||||
mm := MetricsMiddleware(pm)
|
||||
|
||||
// Set up routes using the Gin engine
|
||||
pm.ginEngine.POST("/v1/chat/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/chat/completions", mm, pm.proxyOAIHandler)
|
||||
// Support legacy /v1/completions api, see issue #12
|
||||
pm.ginEngine.POST("/v1/completions", pm.proxyOAIHandler)
|
||||
pm.ginEngine.POST("/v1/completions", mm, pm.proxyOAIHandler)
|
||||
|
||||
// Support embeddings
|
||||
pm.ginEngine.POST("/v1/embeddings", pm.proxyOAIHandler)
|
||||
@@ -360,13 +365,8 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
|
||||
if requestedModel == "" {
|
||||
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
|
||||
return
|
||||
}
|
||||
|
||||
processGroup, realModelName, err := pm.swapProcessGroup(requestedModel)
|
||||
realModelName := c.GetString("ls-real-model-name") // Should be set in MetricsMiddleware
|
||||
processGroup, _, err := pm.swapProcessGroup(realModelName)
|
||||
if err != nil {
|
||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error swapping process group: %s", err.Error()))
|
||||
return
|
||||
|
||||
@@ -24,6 +24,7 @@ func addApiHandlers(pm *ProxyManager) {
|
||||
{
|
||||
apiGroup.POST("/models/unload", pm.apiUnloadAllModels)
|
||||
apiGroup.GET("/events", pm.apiSendEvents)
|
||||
apiGroup.GET("/metrics", pm.apiGetMetrics)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +86,7 @@ type messageType string
|
||||
const (
|
||||
msgTypeModelStatus messageType = "modelStatus"
|
||||
msgTypeLogData messageType = "logData"
|
||||
msgTypeMetrics messageType = "metrics"
|
||||
)
|
||||
|
||||
type messageEnvelope struct {
|
||||
@@ -130,6 +132,18 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
sendMetrics := func(metrics TokenMetrics) {
|
||||
jsonData, err := json.Marshal(metrics)
|
||||
if err == nil {
|
||||
select {
|
||||
case sendBuffer <- messageEnvelope{Type: msgTypeMetrics, Data: string(jsonData)}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send updated models list
|
||||
*/
|
||||
@@ -150,10 +164,20 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
sendLogData("upstream", data)
|
||||
})()
|
||||
|
||||
/**
|
||||
* Send Metrics data
|
||||
*/
|
||||
defer event.On(func(e TokenMetricsEvent) {
|
||||
sendMetrics(e.Metrics)
|
||||
})()
|
||||
|
||||
// send initial batch of data
|
||||
sendLogData("proxy", pm.proxyLogger.GetHistory())
|
||||
sendLogData("upstream", pm.upstreamLogger.GetHistory())
|
||||
sendModels()
|
||||
for _, metrics := range pm.metricsMonitor.GetMetrics() {
|
||||
sendMetrics(metrics)
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -169,3 +193,12 @@ func (pm *ProxyManager) apiSendEvents(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) apiGetMetrics(c *gin.Context) {
|
||||
jsonData, err := pm.metricsMonitor.GetMetricsJSON()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to get metrics"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/json", jsonData)
|
||||
}
|
||||
|
||||
@@ -165,9 +165,11 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
results[key] = response["responseMessage"]
|
||||
result, ok := response["responseMessage"].(string)
|
||||
assert.Equal(t, ok, true)
|
||||
results[key] = result
|
||||
mu.Unlock()
|
||||
}(key)
|
||||
|
||||
@@ -644,7 +646,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
assert.Equal(t, "81", response["h_content_length"])
|
||||
assert.Equal(t, "model1", response["responseMessage"])
|
||||
@@ -672,7 +674,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
var response map[string]string
|
||||
var response map[string]interface{}
|
||||
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||
|
||||
// `temperature` and `stream` are gone but model remains
|
||||
@@ -683,3 +685,69 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
||||
// assert.Equal(t, "abc", response["y_param"])
|
||||
// t.Logf("%v", response)
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a non-streaming request
|
||||
reqBody := `{"model":"model1", "stream": false}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
assert.NotEmpty(t, metrics, "metrics should be recorded for non-streaming request")
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
}
|
||||
|
||||
func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
||||
config := AddDefaultGroupToConfig(Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
},
|
||||
LogLevel: "error",
|
||||
})
|
||||
|
||||
proxy := New(config)
|
||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||
|
||||
// Make a streaming request
|
||||
reqBody := `{"model":"model1", "stream": true}`
|
||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
proxy.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check that metrics were recorded
|
||||
metrics := proxy.metricsMonitor.GetMetrics()
|
||||
assert.NotEmpty(t, metrics, "metrics should be recorded for streaming request")
|
||||
|
||||
// Verify the last metric has the correct model
|
||||
lastMetric := metrics[len(metrics)-1]
|
||||
assert.Equal(t, "model1", lastMetric.Model)
|
||||
assert.Equal(t, 25, lastMetric.InputTokens, "input tokens should be 25")
|
||||
assert.Equal(t, 10, lastMetric.OutputTokens, "output tokens should be 10")
|
||||
assert.Greater(t, lastMetric.TokensPerSecond, 0.0, "tokens per second should be greater than 0")
|
||||
assert.Greater(t, lastMetric.DurationMs, 0, "duration should be greater than 0")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user