diff --git a/proxy/metrics_middleware.go b/proxy/metrics_middleware.go index 1718a94..ee17717 100644 --- a/proxy/metrics_middleware.go +++ b/proxy/metrics_middleware.go @@ -17,6 +17,7 @@ func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc { bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { pm.sendErrorResponse(c, http.StatusBadRequest, "could not ready request body") + c.Abort() return } c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) @@ -24,15 +25,16 @@ func MetricsMiddleware(pm *ProxyManager) gin.HandlerFunc { requestedModel := gjson.GetBytes(bodyBytes, "model").String() if requestedModel == "" { pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") + c.Abort() return } realModelName, found := pm.config.RealModelName(requestedModel) if !found { pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel)) + c.Abort() return } - c.Set("ls-real-model-name", realModelName) writer := &MetricsResponseWriter{ ResponseWriter: c.Writer, diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index d0f7713..3751d50 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -14,6 +14,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -367,9 +368,15 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { return } - realModelName := c.GetString("ls-real-model-name") // Should be set in MetricsMiddleware - if realModelName == "" { - pm.sendErrorResponse(c, http.StatusInternalServerError, "ls-real-model-name not set") + requestedModel := gjson.GetBytes(bodyBytes, "model").String() + if requestedModel == "" { + pm.sendErrorResponse(c, http.StatusBadRequest, "missing or invalid 'model' key") + return + } + + realModelName, found := pm.config.RealModelName(requestedModel) + if !found { + pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("could not find real modelID for %s", requestedModel)) return }