Refactor to use httputil.ReverseProxy (#342)
* Refactor to use httputil.ReverseProxy Refactor manual HTTP proxying logic in Process.ProxyRequest to use the standard library's httputil.ReverseProxy. * Refactor TestProcess_ForceStopWithKill test Update to handle behavior with httputil.ReverseProxy. * Fix gin interface conversion panic
This commit is contained in:
committed by
GitHub
parent
caf9e98b1e
commit
d58a8b85bf
@@ -3,6 +3,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"regexp"
|
"regexp"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -342,6 +343,13 @@ func LoadConfigFromReader(r io.Reader) (Config, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate the proxy URL.
|
||||||
|
if _, err := url.Parse(modelConfig.Proxy); err != nil {
|
||||||
|
return Config{}, fmt.Errorf(
|
||||||
|
"model %s: invalid proxy URL: %w", modelId, err,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
config.Models[modelId] = modelConfig
|
config.Models[modelId] = modelConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,11 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
@@ -40,9 +39,10 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Process struct {
|
type Process struct {
|
||||||
ID string
|
ID string
|
||||||
config config.ModelConfig
|
config config.ModelConfig
|
||||||
cmd *exec.Cmd
|
cmd *exec.Cmd
|
||||||
|
reverseProxy *httputil.ReverseProxy
|
||||||
|
|
||||||
// PR #155 called to cancel the upstream process
|
// PR #155 called to cancel the upstream process
|
||||||
cmdMutex sync.RWMutex
|
cmdMutex sync.RWMutex
|
||||||
@@ -85,10 +85,29 @@ func NewProcess(ID string, healthCheckTimeout int, config config.ModelConfig, pr
|
|||||||
concurrentLimit = config.ConcurrencyLimit
|
concurrentLimit = config.ConcurrencyLimit
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Setup the reverse proxy.
|
||||||
|
proxyURL, err := url.Parse(config.Proxy)
|
||||||
|
if err != nil {
|
||||||
|
proxyLogger.Errorf("<%s> invalid proxy URL %q: %v", ID, config.Proxy, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var reverseProxy *httputil.ReverseProxy
|
||||||
|
if proxyURL != nil {
|
||||||
|
reverseProxy = httputil.NewSingleHostReverseProxy(proxyURL)
|
||||||
|
reverseProxy.ModifyResponse = func(resp *http.Response) error {
|
||||||
|
// prevent nginx from buffering streaming responses (e.g., SSE)
|
||||||
|
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
||||||
|
resp.Header.Set("X-Accel-Buffering", "no")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return &Process{
|
return &Process{
|
||||||
ID: ID,
|
ID: ID,
|
||||||
config: config,
|
config: config,
|
||||||
cmd: nil,
|
cmd: nil,
|
||||||
|
reverseProxy: reverseProxy,
|
||||||
cancelUpstream: nil,
|
cancelUpstream: nil,
|
||||||
processLogger: processLogger,
|
processLogger: processLogger,
|
||||||
proxyLogger: proxyLogger,
|
proxyLogger: proxyLogger,
|
||||||
@@ -480,56 +499,10 @@ func (p *Process) ProxyRequest(w http.ResponseWriter, r *http.Request) {
|
|||||||
startDuration = time.Since(beginStartTime)
|
startDuration = time.Since(beginStartTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyTo := p.config.Proxy
|
if p.reverseProxy != nil {
|
||||||
client := &http.Client{}
|
p.reverseProxy.ServeHTTP(w, r)
|
||||||
req, err := http.NewRequestWithContext(r.Context(), r.Method, proxyTo+r.URL.String(), r.Body)
|
} else {
|
||||||
if err != nil {
|
http.Error(w, fmt.Sprintf("No reverse proxy available for %s", p.ID), http.StatusInternalServerError)
|
||||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
req.Header = r.Header.Clone()
|
|
||||||
|
|
||||||
contentLength, err := strconv.ParseInt(req.Header.Get("content-length"), 10, 64)
|
|
||||||
if err == nil {
|
|
||||||
req.ContentLength = contentLength
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := client.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer resp.Body.Close()
|
|
||||||
for k, vv := range resp.Header {
|
|
||||||
for _, v := range vv {
|
|
||||||
w.Header().Add(k, v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// prevent nginx from buffering streaming responses (e.g., SSE)
|
|
||||||
if strings.Contains(strings.ToLower(resp.Header.Get("Content-Type")), "text/event-stream") {
|
|
||||||
w.Header().Set("X-Accel-Buffering", "no")
|
|
||||||
}
|
|
||||||
w.WriteHeader(resp.StatusCode)
|
|
||||||
|
|
||||||
// faster than io.Copy when streaming
|
|
||||||
buf := make([]byte, 32*1024)
|
|
||||||
for {
|
|
||||||
n, err := resp.Body.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
if _, writeErr := w.Write(buf[:n]); writeErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if flusher, ok := w.(http.Flusher); ok {
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
http.Error(w, err.Error(), http.StatusBadGateway)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
totalTime := time.Since(requestBeginTime)
|
totalTime := time.Since(requestBeginTime)
|
||||||
|
|||||||
@@ -436,7 +436,9 @@ func TestProcess_ForceStopWithKill(t *testing.T) {
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
assert.Contains(t, w.Body.String(), "wsarecv: An existing connection was forcibly closed by the remote host")
|
||||||
} else {
|
} else {
|
||||||
assert.Contains(t, w.Body.String(), "unexpected EOF")
|
// Upstream may be killed mid-response.
|
||||||
|
// Assert an incomplete or partial response.
|
||||||
|
assert.NotEqual(t, "12345", w.Body.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
close(waitChan)
|
close(waitChan)
|
||||||
|
|||||||
@@ -21,6 +21,32 @@ import (
|
|||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TestResponseRecorder adds CloseNotify to httptest.ResponseRecorder.
|
||||||
|
// "If you want to write your own tests around streams you will need a Recorder that can handle CloseNotifier."
|
||||||
|
// The tests can panic otherwise:
|
||||||
|
// panic: interface conversion: *httptest.ResponseRecorder is not http.CloseNotifier: missing method CloseNotify
|
||||||
|
// See: https://github.com/gin-gonic/gin/issues/1815
|
||||||
|
// TestResponseRecorder is taken from gin's own tests: https://github.com/gin-gonic/gin/blob/ce20f107f5dc498ec7489d7739541a25dcd48463/context_test.go#L1747-L1765
|
||||||
|
type TestResponseRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
closeChannel chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TestResponseRecorder) CloseNotify() <-chan bool {
|
||||||
|
return r.closeChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *TestResponseRecorder) closeClient() {
|
||||||
|
r.closeChannel <- true
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateTestResponseRecorder() *TestResponseRecorder {
|
||||||
|
return &TestResponseRecorder{
|
||||||
|
httptest.NewRecorder(),
|
||||||
|
make(chan bool, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
||||||
config := config.AddDefaultGroupToConfig(config.Config{
|
config := config.AddDefaultGroupToConfig(config.Config{
|
||||||
HealthCheckTimeout: 15,
|
HealthCheckTimeout: 15,
|
||||||
@@ -37,7 +63,7 @@ func TestProxyManager_SwapProcessCorrectly(t *testing.T) {
|
|||||||
for _, modelName := range []string{"model1", "model2"} {
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -74,7 +100,7 @@ func TestProxyManager_SwapMultiProcess(t *testing.T) {
|
|||||||
t.Run(requestedModel, func(t *testing.T) {
|
t.Run(requestedModel, func(t *testing.T) {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -116,7 +142,7 @@ func TestProxyManager_PersistentGroupsAreNotSwapped(t *testing.T) {
|
|||||||
for _, requestedModel := range tests {
|
for _, requestedModel := range tests {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -159,7 +185,7 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, key)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions?wait=1000ms", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
@@ -212,7 +238,7 @@ func TestProxyManager_ListModelsHandler(t *testing.T) {
|
|||||||
// Create a test request
|
// Create a test request
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
req.Header.Add("Origin", "i-am-the-origin")
|
req.Header.Add("Origin", "i-am-the-origin")
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
// Call the listModelsHandler
|
// Call the listModelsHandler
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
@@ -311,7 +337,7 @@ models:
|
|||||||
proxy := New(processedConfig)
|
proxy := New(processedConfig)
|
||||||
|
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -387,7 +413,7 @@ func TestProxyManager_ListModelsHandler_SortedByID(t *testing.T) {
|
|||||||
|
|
||||||
// Request models list
|
// Request models list
|
||||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -448,7 +474,7 @@ func TestProxyManager_Shutdown(t *testing.T) {
|
|||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
// send a request to trigger the proxy to load ... this should hang waiting for start up
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
@@ -476,12 +502,12 @@ func TestProxyManager_Unload(t *testing.T) {
|
|||||||
proxy := New(conf)
|
proxy := New(conf)
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
assert.Equal(t, proxy.processGroups[config.DEFAULT_GROUP_ID].processes["model1"].CurrentState(), StateReady)
|
||||||
req = httptest.NewRequest("GET", "/unload", nil)
|
req = httptest.NewRequest("GET", "/unload", nil)
|
||||||
w = httptest.NewRecorder()
|
w = CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
assert.Equal(t, w.Body.String(), "OK")
|
assert.Equal(t, w.Body.String(), "OK")
|
||||||
@@ -519,7 +545,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
|||||||
for _, modelName := range []string{"model1", "model2"} {
|
for _, modelName := range []string{"model1", "model2"} {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, modelName)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -527,7 +553,7 @@ func TestProxyManager_UnloadSingleModel(t *testing.T) {
|
|||||||
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
assert.Equal(t, StateReady, proxy.processGroups[testGroupId].processes["model2"].CurrentState())
|
||||||
|
|
||||||
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
req := httptest.NewRequest("POST", "/api/models/unload/model1", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
if !assert.Equal(t, w.Body.String(), "OK") {
|
if !assert.Equal(t, w.Body.String(), "OK") {
|
||||||
@@ -571,7 +597,7 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("no models loaded", func(t *testing.T) {
|
t.Run("no models loaded", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/running", nil)
|
req := httptest.NewRequest("GET", "/running", nil)
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -589,13 +615,13 @@ func TestProxyManager_RunningEndpoint(t *testing.T) {
|
|||||||
// Load just a model.
|
// Load just a model.
|
||||||
reqBody := `{"model":"model1"}`
|
reqBody := `{"model":"model1"}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
// Simulate browser call for the `/running` endpoint.
|
// Simulate browser call for the `/running` endpoint.
|
||||||
req = httptest.NewRequest("GET", "/running", nil)
|
req = httptest.NewRequest("GET", "/running", nil)
|
||||||
w = httptest.NewRecorder()
|
w = CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
var response RunningResponse
|
var response RunningResponse
|
||||||
@@ -647,7 +673,7 @@ func TestProxyManager_AudioTranscriptionHandler(t *testing.T) {
|
|||||||
// Create the request with the multipart form data
|
// Create the request with the multipart form data
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
// Verify the response
|
// Verify the response
|
||||||
@@ -682,7 +708,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
t.Run("useModelName over rides requested model: /v1/chat/completions", func(t *testing.T) {
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
reqBody := fmt.Sprintf(`{"model":"%s"}`, requestedModel)
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -716,7 +742,7 @@ func TestProxyManager_UseModelName(t *testing.T) {
|
|||||||
// Create the request with the multipart form data
|
// Create the request with the multipart form data
|
||||||
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
req := httptest.NewRequest("POST", "/v1/audio/transcriptions", &b)
|
||||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
// Verify the response
|
// Verify the response
|
||||||
@@ -784,7 +810,7 @@ func TestProxyManager_CORSOptionsHandler(t *testing.T) {
|
|||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
|
|
||||||
assert.Equal(t, tt.expectedStatus, w.Code)
|
assert.Equal(t, tt.expectedStatus, w.Code)
|
||||||
@@ -812,7 +838,7 @@ models:
|
|||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
t.Run("main model name", func(t *testing.T) {
|
t.Run("main model name", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
req := httptest.NewRequest("GET", "/upstream/model1/test", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
@@ -820,7 +846,7 @@ models:
|
|||||||
|
|
||||||
t.Run("model alias", func(t *testing.T) {
|
t.Run("model alias", func(t *testing.T) {
|
||||||
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
req := httptest.NewRequest("GET", "/upstream/model-alias/test", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "model1", rec.Body.String())
|
assert.Equal(t, "model1", rec.Body.String())
|
||||||
@@ -841,7 +867,7 @@ func TestProxyManager_ChatContentLength(t *testing.T) {
|
|||||||
|
|
||||||
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
reqBody := fmt.Sprintf(`{"model":"%s", "x": "this is just some content to push the length out a bit"}`, "model1")
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -869,7 +895,7 @@ func TestProxyManager_FiltersStripParams(t *testing.T) {
|
|||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
reqBody := `{"model":"model1", "temperature":0.1, "x_param":"123", "y_param":"abc", "stream":true}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -900,7 +926,7 @@ func TestProxyManager_MiddlewareWritesMetrics_NonStreaming(t *testing.T) {
|
|||||||
// Make a non-streaming request
|
// Make a non-streaming request
|
||||||
reqBody := `{"model":"model1", "stream": false}`
|
reqBody := `{"model":"model1", "stream": false}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -935,7 +961,7 @@ func TestProxyManager_MiddlewareWritesMetrics_Streaming(t *testing.T) {
|
|||||||
// Make a streaming request
|
// Make a streaming request
|
||||||
reqBody := `{"model":"model1", "stream": true}`
|
reqBody := `{"model":"model1", "stream": true}`
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -967,7 +993,7 @@ func TestProxyManager_HealthEndpoint(t *testing.T) {
|
|||||||
proxy := New(config)
|
proxy := New(config)
|
||||||
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
defer proxy.StopProcesses(StopWaitForInflightRequest)
|
||||||
req := httptest.NewRequest("GET", "/health", nil)
|
req := httptest.NewRequest("GET", "/health", nil)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
assert.Equal(t, http.StatusOK, rec.Code)
|
assert.Equal(t, http.StatusOK, rec.Code)
|
||||||
assert.Equal(t, "OK", rec.Body.String())
|
assert.Equal(t, "OK", rec.Body.String())
|
||||||
@@ -988,7 +1014,7 @@ func TestProxyManager_CompletionEndpoint(t *testing.T) {
|
|||||||
|
|
||||||
reqBody := `{"model":"model1"}`
|
reqBody := `{"model":"model1"}`
|
||||||
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/completion", bytes.NewBufferString(reqBody))
|
||||||
w := httptest.NewRecorder()
|
w := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(w, req)
|
proxy.ServeHTTP(w, req)
|
||||||
assert.Equal(t, http.StatusOK, w.Code)
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
@@ -1080,7 +1106,7 @@ func TestProxyManager_StreamingEndpointsReturnNoBufferingHeader(t *testing.T) {
|
|||||||
|
|
||||||
req := httptest.NewRequest("GET", endpoint, nil)
|
req := httptest.NewRequest("GET", endpoint, nil)
|
||||||
req = req.WithContext(ctx)
|
req = req.WithContext(ctx)
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
|
|
||||||
// Run handler in goroutine and wait for context timeout
|
// Run handler in goroutine and wait for context timeout
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@@ -1119,7 +1145,7 @@ func TestProxyManager_ProxiedStreamingEndpointReturnsNoBufferingHeader(t *testin
|
|||||||
reqBody := `{"model":"streaming-model"}`
|
reqBody := `{"model":"streaming-model"}`
|
||||||
// simple-responder will return text/event-stream when stream=true is in the query
|
// simple-responder will return text/event-stream when stream=true is in the query
|
||||||
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
req := httptest.NewRequest("POST", "/v1/chat/completions?stream=true", bytes.NewBufferString(reqBody))
|
||||||
rec := httptest.NewRecorder()
|
rec := CreateTestResponseRecorder()
|
||||||
|
|
||||||
proxy.ServeHTTP(rec, req)
|
proxy.ServeHTTP(rec, req)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user