Add Access-Control-Allow-Origin CORS header to /v1/models endpoint

- match behavior of llama.cpp where the Origin in request is used
- add test for listModelsHandler
This commit is contained in:
Benson Wong
2024-12-03 15:53:59 -08:00
parent da2326bdc7
commit 18c134624d
2 changed files with 73 additions and 0 deletions

View File

@@ -98,6 +98,10 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
// Set the Content-Type header to application/json
c.Header("Content-Type", "application/json")
if origin := c.Request.Header.Get("Origin"); origin != "" {
c.Header("Access-Control-Allow-Origin", origin)
}
// Encode the data as JSON and write it to the response writer
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))

View File

@@ -2,6 +2,7 @@ package proxy
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
@@ -141,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
assert.Equal(t, key, result)
}
}
func TestProxyManager_ListModelsHandler(t *testing.T) {
config := &Config{
HealthCheckTimeout: 15,
Models: map[string]ModelConfig{
"model1": getTestSimpleResponderConfig("model1"),
"model2": getTestSimpleResponderConfig("model2"),
"model3": getTestSimpleResponderConfig("model3"),
},
}
proxy := New(config)
// Create a test request
req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Add("Origin", "i-am-the-origin")
w := httptest.NewRecorder()
// Call the listModelsHandler
proxy.HandlerFunc(w, req)
// Check the response status code
assert.Equal(t, http.StatusOK, w.Code)
// Check for Access-Control-Allow-Origin
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
// Parse the JSON response
var response struct {
Data []map[string]interface{} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("Failed to parse JSON response: %v", err)
}
// Check the number of models returned
assert.Len(t, response.Data, 3)
// Check the details of each model
expectedModels := map[string]struct{}{
"model1": {},
"model2": {},
"model3": {},
}
for _, model := range response.Data {
modelID, ok := model["id"].(string)
assert.True(t, ok, "model ID should be a string")
_, exists := expectedModels[modelID]
assert.True(t, exists, "unexpected model ID: %s", modelID)
delete(expectedModels, modelID)
object, ok := model["object"].(string)
assert.True(t, ok, "object should be a string")
assert.Equal(t, "model", object)
created, ok := model["created"].(float64)
assert.True(t, ok, "created should be a number")
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
ownedBy, ok := model["owned_by"].(string)
assert.True(t, ok, "owned_by should be a string")
assert.Equal(t, "llama-swap", ownedBy)
}
// Ensure all expected models were returned
assert.Empty(t, expectedModels, "not all expected models were returned")
}