Add Access-Control-Allow-Origin CORS header to /v1/models endpoint
- match behavior of llama.cpp where the Origin in request is used - add test for listModelsHandler
This commit is contained in:
@@ -98,6 +98,10 @@ func (pm *ProxyManager) listModelsHandler(c *gin.Context) {
|
|||||||
// Set the Content-Type header to application/json
|
// Set the Content-Type header to application/json
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
if origin := c.Request.Header.Get("Origin"); origin != "" {
|
||||||
|
c.Header("Access-Control-Allow-Origin", origin)
|
||||||
|
}
|
||||||
|
|
||||||
// Encode the data as JSON and write it to the response writer
|
// Encode the data as JSON and write it to the response writer
|
||||||
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
if err := json.NewEncoder(c.Writer).Encode(map[string]interface{}{"data": data}); err != nil {
|
||||||
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("error encoding JSON"))
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
@@ -141,3 +142,71 @@ func TestProxyManager_SwapMultiProcessParallelRequests(t *testing.T) {
|
|||||||
assert.Equal(t, key, result)
|
assert.Equal(t, key, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyManager_ListModelsHandler(t *testing.T) {
|
||||||
|
config := &Config{
|
||||||
|
HealthCheckTimeout: 15,
|
||||||
|
Models: map[string]ModelConfig{
|
||||||
|
"model1": getTestSimpleResponderConfig("model1"),
|
||||||
|
"model2": getTestSimpleResponderConfig("model2"),
|
||||||
|
"model3": getTestSimpleResponderConfig("model3"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New(config)
|
||||||
|
|
||||||
|
// Create a test request
|
||||||
|
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||||
|
req.Header.Add("Origin", "i-am-the-origin")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
// Call the listModelsHandler
|
||||||
|
proxy.HandlerFunc(w, req)
|
||||||
|
|
||||||
|
// Check the response status code
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
|
||||||
|
// Check for Access-Control-Allow-Origin
|
||||||
|
assert.Equal(t, req.Header.Get("Origin"), w.Result().Header.Get("Access-Control-Allow-Origin"))
|
||||||
|
|
||||||
|
// Parse the JSON response
|
||||||
|
var response struct {
|
||||||
|
Data []map[string]interface{} `json:"data"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||||
|
t.Fatalf("Failed to parse JSON response: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the number of models returned
|
||||||
|
assert.Len(t, response.Data, 3)
|
||||||
|
|
||||||
|
// Check the details of each model
|
||||||
|
expectedModels := map[string]struct{}{
|
||||||
|
"model1": {},
|
||||||
|
"model2": {},
|
||||||
|
"model3": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, model := range response.Data {
|
||||||
|
modelID, ok := model["id"].(string)
|
||||||
|
assert.True(t, ok, "model ID should be a string")
|
||||||
|
_, exists := expectedModels[modelID]
|
||||||
|
assert.True(t, exists, "unexpected model ID: %s", modelID)
|
||||||
|
delete(expectedModels, modelID)
|
||||||
|
|
||||||
|
object, ok := model["object"].(string)
|
||||||
|
assert.True(t, ok, "object should be a string")
|
||||||
|
assert.Equal(t, "model", object)
|
||||||
|
|
||||||
|
created, ok := model["created"].(float64)
|
||||||
|
assert.True(t, ok, "created should be a number")
|
||||||
|
assert.Greater(t, created, float64(0)) // Assuming the timestamp is positive
|
||||||
|
|
||||||
|
ownedBy, ok := model["owned_by"].(string)
|
||||||
|
assert.True(t, ok, "owned_by should be a string")
|
||||||
|
assert.Equal(t, "llama-swap", ownedBy)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure all expected models were returned
|
||||||
|
assert.Empty(t, expectedModels, "not all expected models were returned")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user