From 18c134624df8a6dbe92ac33d9a967e2932f15848 Mon Sep 17 00:00:00 2001 From: Benson Wong Date: Tue, 3 Dec 2024 15:53:59 -0800 Subject: [PATCH] Add Access-Control-Allow-Origin CORS header to /v1/models endpoint - match behavior of llama.cpp where the Origin in request is used - add test for listModelsHandler --- proxy/proxymanager.go | 4 +++ proxy/proxymanager_test.go | 69 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 7994e97..0c26990 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -98,6 +98,10 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) { // Set the Content-Type header to application/json c.Header("Content-Type", "application/json") + if origin := c.Request.Header.Get("Origin"); origin != "" { + c.Header("Access-Control-Allow-Origin", origin) + } + // Encode the data as JSON and write it to the response writer if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil { c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON")) diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index 7b8110c..1d96c14 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -2,6 +2,7 @@ package proxy import ( "bytes" + "encoding/json" "fmt" "net/http" "net/http/httptest" @@ -141,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) { assert.Equal(t, key, result) } } + +func TestProxyManager_ListModelsHandler(t *testing.T) { + config := &Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + "model2": getTestSimpleResponderConfig("model2"), + "model3": getTestSimpleResponderConfig("model3"), + }, + } + + proxy := New(config) + + // Create a test request + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Add("Origin", "i-am-the-origin") + w := httptest.NewRecorder() + + // Call the listModelsHandler + proxy.HandlerFunc(w, req) + + // Check the response status code + assert.Equal(t, http.StatusOK, w.Code) + + // Check for Access-Control-Allow-Origin + assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin")) + + // Parse the JSON response + var response struct { + Data []map[string]interface{} `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil { + t.Fatalf("Failed to parse JSON response: %v", err) + } + + // Check the number of models returned + assert.Len(t, response.Data, 3) + + // Check the details of each model + expectedModels := map[string]struct{}{ + "model1": {}, + "model2": {}, + "model3": {}, + } + + for _, model := range response.Data { + modelID, ok := model["id"].(string) + assert.True(t, ok, "model ID should be a string") + _, exists := expectedModels[modelID] + assert.True(t, exists, "unexpected model ID: %s", modelID) + delete(expectedModels, modelID) + + object, ok := model["object"].(string) + assert.True(t, ok, "object should be a string") + assert.Equal(t, "model", object) + + created, ok := model["created"].(float64) + assert.True(t, ok, "created should be a number") + assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive + + ownedBy, ok := model["owned_by"].(string) + assert.True(t, ok, "owned_by should be a string") + assert.Equal(t, "llama-swap", ownedBy) + } + + // Ensure all expected models were returned + assert.Empty(t, expectedModels, "not all expected models were returned") +}