From 21d7973d11f6553b9beb24e7541fcc29654e80f5 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Mon, 5 May 2025 10:46:26 -0700 Subject: [PATCH] Improve content-length handling (#115) ref: See #114 * Improve content-length handling - Content length was not always being sent - Add tests for content-length --- misc/simple-responder/simple-responder.go | 18 +++++++++++--- proxy/proxymanager.go | 15 +++++++----- proxy/proxymanager_test.go | 30 ++++++++++++++++++++++- 3 files changed, 52 insertions(+), 11 deletions(-) diff --git a/misc/simple-responder/simple-responder.go b/misc/simple-responder/simple-responder.go index fcf8354..0512972 100644 --- a/misc/simple-responder/simple-responder.go +++ b/misc/simple-responder/simple-responder.go @@ -33,14 +33,17 @@ func main() { // Set up the handler function using the provided response message r.POST("/v1/chat/completions", func(c *gin.Context) { - c.Header("Content-Type", "text/plain") + c.Header("Content-Type", "application/json") // add a wait to simulate a slow query if wait, err := time.ParseDuration(c.Query("wait")); err == nil { time.Sleep(wait) } - c.String(200, *responseMessage) + c.JSON(http.StatusOK, gin.H{ + "responseMessage": *responseMessage, + "h_content_length": c.Request.Header.Get("Content-Length"), + }) }) // for issue #62 to check model name strips profile slug @@ -63,8 +66,11 @@ func main() { }) r.POST("/v1/completions", func(c *gin.Context) { - c.Header("Content-Type", "text/plain") - c.String(200, *responseMessage) + c.Header("Content-Type", "application/json") + c.JSON(http.StatusOK, gin.H{ + "responseMessage": *responseMessage, + }) + }) // issue #41 @@ -104,6 +110,10 @@ func main() { c.JSON(http.StatusOK, gin.H{ "text": fmt.Sprintf("The length of the file is %d bytes", fileSize), "model": model, + + // expose some header values for testing + "h_content_type": c.GetHeader("Content-Type"), + "h_content_length": c.GetHeader("Content-Length"), }) }) diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 9763722..f1c5661 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -371,7 +371,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { // dechunk it as we already have all the body bytes see issue #11 c.Request.Header.Del("transfer-encoding") - c.Request.Header.Del("content-length") c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes))) if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil { @@ -382,11 +381,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) { } func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { - // We need to reconstruct the multipart form in any case since the body is consumed - // Create a new buffer for the reconstructed request - var requestBuffer bytes.Buffer - multipartWriter := multipart.NewWriter(&requestBuffer) - // Parse multipart form if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error())) @@ -406,6 +400,11 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { return } + // We need to reconstruct the multipart form in any case since the body is consumed + // Create a new buffer for the reconstructed request + var requestBuffer bytes.Buffer + multipartWriter := multipart.NewWriter(&requestBuffer) + // Copy all form values for key, values := range c.Request.MultipartForm.Value { for _, value := range values { @@ -479,6 +478,10 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) { modifiedReq.Header = c.Request.Header.Clone() modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType()) + // set the content length of the body + modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len())) + modifiedReq.ContentLength = int64(requestBuffer.Len()) + // Use the modified request for proxying if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil { pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error())) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index cddfa40..83cbd09 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -8,6 +8,7 @@ import ( "mime/multipart" "net/http" "net/http/httptest" + "strconv" "sync" "testing" "time" @@ -165,7 +166,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { mu.Lock() - results[key] = w.Body.String() + var response map[string]string + assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) + results[key] = response["responseMessage"] mu.Unlock() }(key) @@ -442,6 +445,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "TheExpectedModel", response["model"]) assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder + assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"]) } // Test useModelName in configuration sends overrides what is sent to upstream @@ -592,3 +596,27 @@ func TestProxyManager_Upstream(t *testing.T) { assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "model1", rec.Body.String()) } + +func TestProxyManager_ChatContentLength(t *testing.T) { + config := AddDefaultGroupToConfig(Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + }, + LogLevel: "error", + }) + + proxy := New(config) + defer proxy.StopProcesses() + + reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1") + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + proxy.HandlerFunc(w, req) + assert.Equal(t, http.StatusOK, w.Code) + var response map[string]string + assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response)) + assert.Equal(t, "81", response["h_content_length"]) + assert.Equal(t, "model1", response["responseMessage"]) +}