diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 8a22d07..09e09d1 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -72,13 +72,26 @@ func New(config *Config) *ProxyManager { }) } - // see: https://github.com/mostlygeek/llama-swap/issues/42 + // see: issue: #81, #77 and #42 for CORS issues // respond with permissive OPTIONS for any endpoint pm.ginEngine.Use(func(c *gin.Context) { + + // set this for all requests + c.Header("Access-Control-Allow-Origin", "*") + if c.Request.Method == "OPTIONS" { - c.Header("Access-Control-Allow-Origin", "*") - c.Header("Access-Control-Allow-Methods", "*") - c.Header("Access-Control-Allow-Headers", "*") + c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS") + + // allow whatever the client requested by default + if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" { + c.Header("Access-Control-Allow-Headers", headers) + } else { + c.Header( + "Access-Control-Allow-Headers", + "Content-Type, Authorization, Accept, X-Requested-With", + ) + } + c.Header("Access-Control-Max-Age", "86400") c.AbortWithStatus(http.StatusNoContent) return } diff --git a/proxy/proxymanager_test.go b/proxy/proxymanager_test.go index a8c1576..ba7aefe 100644 --- a/proxy/proxymanager_test.go +++ b/proxy/proxymanager_test.go @@ -639,5 +639,95 @@ func TestProxyManager_UseModelName(t *testing.T) { assert.Equal(t, upstreamModelName, response["model"]) }) } - +} + +func TestProxyManager_CORSOptionsHandler(t *testing.T) { + config := &Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + }, + LogRequests: true, + } + + tests := []struct { + name string + method string + requestHeaders map[string]string + expectedStatus int + expectedHeaders map[string]string + }{ + { + name: "OPTIONS with no headers", + method: "OPTIONS", + expectedStatus: http.StatusNoContent, + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS", + "Access-Control-Allow-Headers": "Content-Type, Authorization, Accept, X-Requested-With", + }, + }, + { + name: "OPTIONS with specific headers", + method: "OPTIONS", + requestHeaders: map[string]string{ + "Access-Control-Request-Headers": "X-Custom-Header, Some-Other-Header", + }, + expectedStatus: http.StatusNoContent, + expectedHeaders: map[string]string{ + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, PUT, PATCH, DELETE, OPTIONS", + "Access-Control-Allow-Headers": "X-Custom-Header, Some-Other-Header", + }, + }, + { + name: "Non-OPTIONS request", + method: "GET", + expectedStatus: http.StatusNotFound, // Since we don't have a GET route defined + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + proxy := New(config) + defer proxy.StopProcesses() + + req := httptest.NewRequest(tt.method, "/v1/chat/completions", nil) + for k, v := range tt.requestHeaders { + req.Header.Set(k, v) + } + + w := httptest.NewRecorder() + proxy.ginEngine.ServeHTTP(w, req) + + assert.Equal(t, tt.expectedStatus, w.Code) + + for header, expectedValue := range tt.expectedHeaders { + assert.Equal(t, expectedValue, w.Header().Get(header)) + } + }) + } +} + +func TestProxyManager_CORSHeadersInRegularRequest(t *testing.T) { + config := &Config{ + HealthCheckTimeout: 15, + Models: map[string]ModelConfig{ + "model1": getTestSimpleResponderConfig("model1"), + }, + LogRequests: true, + } + + proxy := New(config) + defer proxy.StopProcesses() + + // Test that CORS headers are present in regular POST requests + reqBody := `{"model":"model1"}` + req := httptest.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody)) + w := httptest.NewRecorder() + + proxy.ginEngine.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin")) }