Improve content-length handling (#115)
ref: See #114 * Improve content-length handling - Content length was not always being sent - Add tests for content-length
This commit is contained in:
@@ -33,14 +33,17 @@ func main() {
|
|||||||
|
|
||||||
// Set up the handler function using the provided response message
|
// Set up the handler function using the provided response message
|
||||||
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
r.POST("/v1/chat/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
// add a wait to simulate a slow query
|
// add a wait to simulate a slow query
|
||||||
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
if wait, err := time.ParseDuration(c.Query("wait")); err == nil {
|
||||||
time.Sleep(wait)
|
time.Sleep(wait)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.String(200, *responseMessage)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
"h_content_length": c.Request.Header.Get("Content-Length"),
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
// for issue #62 to check model name strips profile slug
|
// for issue #62 to check model name strips profile slug
|
||||||
@@ -63,8 +66,11 @@ func main() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
r.POST("/v1/completions", func(c *gin.Context) {
|
r.POST("/v1/completions", func(c *gin.Context) {
|
||||||
c.Header("Content-Type", "text/plain")
|
c.Header("Content-Type", "application/json")
|
||||||
c.String(200, *responseMessage)
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"responseMessage": *responseMessage,
|
||||||
|
})
|
||||||
|
|
||||||
})
|
})
|
||||||
|
|
||||||
// issue #41
|
// issue #41
|
||||||
@@ -104,6 +110,10 @@ func main() {
|
|||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
"text": fmt.Sprintf("The length of the file is %d bytes", fileSize),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
|
||||||
|
// expose some header values for testing
|
||||||
|
"h_content_type": c.GetHeader("Content-Type"),
|
||||||
|
"h_content_length": c.GetHeader("Content-Length"),
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -371,7 +371,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
|
|
||||||
// dechunk it as we already have all the body bytes see issue #11
|
// dechunk it as we already have all the body bytes see issue #11
|
||||||
c.Request.Header.Del("transfer-encoding")
|
c.Request.Header.Del("transfer-encoding")
|
||||||
c.Request.Header.Del("content-length")
|
|
||||||
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
c.Request.Header.Add("content-length", strconv.Itoa(len(bodyBytes)))
|
||||||
|
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, c.Request); err != nil {
|
||||||
@@ -382,11 +381,6 @@ func (pm *ProxyManager) proxyOAIHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
||||||
// We need to reconstruct the multipart form in any case since the body is consumed
|
|
||||||
// Create a new buffer for the reconstructed request
|
|
||||||
var requestBuffer bytes.Buffer
|
|
||||||
multipartWriter := multipart.NewWriter(&requestBuffer)
|
|
||||||
|
|
||||||
// Parse multipart form
|
// Parse multipart form
|
||||||
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory, larger files go to tmp disk
|
||||||
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusBadRequest, fmt.Sprintf("error parsing multipart form: %s", err.Error()))
|
||||||
@@ -406,6 +400,11 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We need to reconstruct the multipart form in any case since the body is consumed
|
||||||
|
// Create a new buffer for the reconstructed request
|
||||||
|
var requestBuffer bytes.Buffer
|
||||||
|
multipartWriter := multipart.NewWriter(&requestBuffer)
|
||||||
|
|
||||||
// Copy all form values
|
// Copy all form values
|
||||||
for key, values := range c.Request.MultipartForm.Value {
|
for key, values := range c.Request.MultipartForm.Value {
|
||||||
for _, value := range values {
|
for _, value := range values {
|
||||||
@@ -479,6 +478,10 @@ func (pm *ProxyManager) proxyOAIPostFormHandler(c *gin.Context) {
|
|||||||
modifiedReq.Header = c.Request.Header.Clone()
|
modifiedReq.Header = c.Request.Header.Clone()
|
||||||
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
modifiedReq.Header.Set("Content-Type", multipartWriter.FormDataContentType())
|
||||||
|
|
||||||
|
// set the content length of the body
|
||||||
|
modifiedReq.Header.Set("Content-Length", strconv.Itoa(requestBuffer.Len()))
|
||||||
|
modifiedReq.ContentLength = int64(requestBuffer.Len())
|
||||||
|
|
||||||
// Use the modified request for proxying
|
// Use the modified request for proxying
|
||||||
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
if err := processGroup.ProxyRequest(realModelName, c.Writer, modifiedReq); err != nil {
|
||||||
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
pm.sendErrorResponse(c, http.StatusInternalServerError, fmt.Sprintf("error proxying request: %s", err.Error()))
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -165,7 +166,9 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
|
|
||||||
results[key] = w.Body.String()
|
var response map[string]string
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
results[key] = response["responseMessage"]
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
}(key)
|
}(key)
|
||||||
|
|
||||||
@@ -442,6 +445,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, "TheExpectedModel", response["model"])
|
assert.Equal(t, "TheExpectedModel", response["model"])
|
||||||
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
assert.Equal(t, response["text"], fmt.Sprintf("The length of the file is %d bytes", contentLength)) // matches simple-responder
|
||||||
|
assert.Equal(t, strconv.Itoa(370+contentLength), response["h_content_length"])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test useModelName in configuration sends overrides what is sent to upstream
|
// Test useModelName in configuration sends overrides what is sent to upstream
|
||||||
@@ -592,3 +596,27 @@ func TestProxyManager_Upstream(t *testing.T) {
|
|||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ChatContentLength(t *testing.T) {
|
||||||
|
config := AddDefaultGroupToConfig(Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
},
|
||||||
|
LogLevel: "error",
|
||||||
|
})
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
defer proxy.StopProcesses()
|
||||||
|
|
||||||
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||||
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
var response map[string]string
|
||||||
|
assert.NoError(t, json.Unmarshal(w.Body.Bytes(), &response))
|
||||||
|
assert.Equal(t, "81", response["h_content_length"])
|
||||||
|
assert.Equal(t, "model1", response["responseMessage"])
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user