Decouple MetricsMiddleware from downstream handlers (#206)

* Decouple MetricsMiddleware from downstream handlers

Remove ls-real-model-name optimization. Within proxyOAIHandler the
request body's bytes are required for various rewriting features
anyways. This negated any benefits from trying not to parse it twice.
This commit is contained in:
Benson Wong
2025-07-27 10:36:06 -07:00
committed by GitHub
parent 8c693e7fcf
commit fd50932dbc
2 changed files with 13 additions and 4 deletions

View File

@@ -17,6 +17,7 @@ func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body")
c.Abort()
return
}
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
@@ -24,15 +25,16 @@ func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc {
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
c.Abort()
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
c.Abort()
return
}
c.Set("ls-real-model-name", realModelName)
writer := &MetricsResponseWriter{
ResponseWriter: c.Writer,

View File

@@ -14,6 +14,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -367,9 +368,15 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
return
}
realModelName := c.GetString("ls-real-model-name") // Should be set in MetricsMiddleware
if realModelName == "" {
pm.sendErrorResponse(c, http.StatusInternalServerError, "ls-real-model-name not set")
requestedModel := gjson.GetBytes(bodyBytes, "model").String()
if requestedModel == "" {
pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key")
return
}
realModelName, found := pm.config.RealModelName(requestedModel)
if !found {
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel))
return
}