Improve content-length handling (#115)

ref: See #114

* Improve content-length handling
- Content length was not always being sent
- Add tests for content-length
This commit is contained in:
Benson Wong
2025-05-05 10:46:26 -07:00
committed by GitHub
parent cc450e9c5f
commit 21d7973d11
3 changed files with 52 additions and 11 deletions

View File

@@ -371,7 +371,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
// dechunk it as we already have all the body bytes see issue #11
c.Request.Header.Del("transfer-encoding")
c.Request.Header.Del("content-length")
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
@@ -382,11 +381,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
}
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Parse multipart form
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
@@ -406,6 +400,11 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
return
}
// We need to reconstruct the multipart form in any case since the body is consumed
// Create a new buffer for the reconstructed request
var requestBuffer bytes.Buffer
multipartWriter := multipart.NewWriter(&requestBuffer)
// Copy all form values
for key, values := range c.Request.MultipartForm.Value {
for _, value := range values {
@@ -479,6 +478,10 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
modifiedReq.Header = c.Request.Header.Clone()
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
// set the content length of the body
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
modifiedReq.ContentLength = int64(requestBuffer.Len())
// Use the modified request for proxying
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))

View File

@@ -8,6 +8,7 @@ import (
"mime/multipart"
"net/http"
"net/http/httptest"
"strconv"
"sync"
"testing"
"time"
@@ -165,7 +166,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
mu.Lock()
results[key] = w.Body.String()
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
results[key] = response["responseMessage"]
mu.Unlock()
}(key)
@@ -442,6 +445,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "TheExpectedModel", response["model"])
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
}
// Test useModelName in configuration sends overrides what is sent to upstream
@@ -592,3 +596,27 @@ func TestProxyManager_Upstream(t *testing.T) {
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "model1", rec.Body.String())
}
func TestProxyManager_ChatContentLength(t *testing.T) {
config := AddDefaultGroupToConfig(Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
},
LogLevel: "error",
})
proxy := New(config)
defer proxy.StopProcesses()
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
w := httptest.NewRecorder()
proxy.HandlerFunc(w, req)
assert.Equal(t, http.StatusOK, w.Code)
var response map[string]string
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
assert.Equal(t, "81", response["h_content_length"])
assert.Equal(t, "model1", response["responseMessage"])
}