Add Access-Control-Allow-Origin CORS header to /v1/models endpoint
- match behavior of llama.cpp where the Origin in request is used - add test for listModelsHandler
This commit is contained in:
@@ -98,6 +98,10 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
||||
// Set the Content-Type header to application/json
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||
c.Header("Access-Control-Allow-Origin", origin)
|
||||
}
|
||||
|
||||
// Encode the data as JSON and write it to the response writer
|
||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
||||
|
||||
@@ -2,6 +2,7 @@ package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -141,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
||||
assert.Equal(t, key, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||
config := &Config{
|
||||
HealthCheckTimeout: 15,
|
||||
Models: map[string]ModelConfig{
|
||||
"model1": getTestSimpleResponderConfig("model1"),
|
||||
"model2": getTestSimpleResponderConfig("model2"),
|
||||
"model3": getTestSimpleResponderConfig("model3"),
|
||||
},
|
||||
}
|
||||
|
||||
proxy := New(config)
|
||||
|
||||
// Create a test request
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Add("Origin", "i-am-the-origin")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
// Call the listModelsHandler
|
||||
proxy.HandlerFunc(w, req)
|
||||
|
||||
// Check the response status code
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
// Check for Access-Control-Allow-Origin
|
||||
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
|
||||
|
||||
// Parse the JSON response
|
||||
var response struct {
|
||||
Data []map[string]interface{} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||
}
|
||||
|
||||
// Check the number of models returned
|
||||
assert.Len(t, response.Data, 3)
|
||||
|
||||
// Check the details of each model
|
||||
expectedModels := map[string]struct{}{
|
||||
"model1": {},
|
||||
"model2": {},
|
||||
"model3": {},
|
||||
}
|
||||
|
||||
for _, model := range response.Data {
|
||||
modelID, ok := model["id"].(string)
|
||||
assert.True(t, ok, "model ID should be a string")
|
||||
_, exists := expectedModels[modelID]
|
||||
assert.True(t, exists, "unexpected model ID: %s", modelID)
|
||||
delete(expectedModels, modelID)
|
||||
|
||||
object, ok := model["object"].(string)
|
||||
assert.True(t, ok, "object should be a string")
|
||||
assert.Equal(t, "model", object)
|
||||
|
||||
created, ok := model["created"].(float64)
|
||||
assert.True(t, ok, "created should be a number")
|
||||
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
|
||||
|
||||
ownedBy, ok := model["owned_by"].(string)
|
||||
assert.True(t, ok, "owned_by should be a string")
|
||||
assert.Equal(t, "llama-swap", ownedBy)
|
||||
}
|
||||
|
||||
// Ensure all expected models were returned
|
||||
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user