diff --git a/proxy/proxymanager.go b/proxy/proxymanager.go index 09e09d1..4c4fee5 100644 --- a/proxy/proxymanager.go +++ b/proxy/proxymanager.go @@ -84,7 +84,8 @@ func New(config *Config) *ProxyManager { // allow whatever the client requested by default if headers := c.Request.Header.Get("Access-Control-Request-Headers"); headers != "" { - c.Header("Access-Control-Allow-Headers", headers) + sanitized := SanitizeAccessControlRequestHeaderValues(headers) + c.Header("Access-Control-Allow-Headers", sanitized) } else { c.Header( "Access-Control-Allow-Headers", diff --git a/proxy/sanitize_cors.go b/proxy/sanitize_cors.go new file mode 100644 index 0000000..70873fa --- /dev/null +++ b/proxy/sanitize_cors.go @@ -0,0 +1,43 @@ +package proxy + +import ( + "strings" +) + +func isTokenChar(r rune) bool { + switch { + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case strings.ContainsRune("!#$%&'*+-.^_`|~", r): + default: + return false + } + return true +} + +func SanitizeAccessControlRequestHeaderValues(headerValues string) string { + parts := strings.Split(headerValues, ",") + valid := make([]string, 0, len(parts)) + + for _, p := range parts { + v := strings.TrimSpace(p) + if v == "" { + continue + } + + validPart := true + for _, c := range v { + if !isTokenChar(c) { + validPart = false + break + } + } + + if validPart { + valid = append(valid, v) + } + } + + return strings.Join(valid, ", ") +} diff --git a/proxy/sanitize_cors_test.go b/proxy/sanitize_cors_test.go new file mode 100644 index 0000000..ad11fcc --- /dev/null +++ b/proxy/sanitize_cors_test.go @@ -0,0 +1,77 @@ +package proxy + +import "testing" + +func TestSanitizeAccessControlRequestHeaderValues(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "empty string", + input: "", + expected: "", + }, + { + name: "whitespace only", + input: " ", + expected: "", + }, + { + name: "single valid value", + input: "content-type", + expected: "content-type", + }, + { + name: "multiple valid values", + input: "content-type, authorization, x-requested-with", + expected: "content-type, authorization, x-requested-with", + }, + { + name: "values with extra spaces", + input: " content-type , authorization ", + expected: "content-type, authorization", + }, + { + name: "values with tabs", + input: "content-type,\tauthorization", + expected: "content-type, authorization", + }, + { + name: "values with invalid characters", + input: "content-type, auth\n, x-requested-with\r", + expected: "content-type, auth, x-requested-with", + }, + { + name: "empty values in list", + input: "content-type,,authorization", + expected: "content-type, authorization", + }, + { + name: "leading and trailing commas", + input: ",content-type,authorization,", + expected: "content-type, authorization", + }, + { + name: "mixed valid and invalid values", + input: "content-type, \x00invalid, x-requested-with", + expected: "content-type, x-requested-with", + }, + { + name: "mixed case values", + input: "Content-Type, my-Valid-Header, Another-hEader", + expected: "Content-Type, my-Valid-Header, Another-hEader", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SanitizeAccessControlRequestHeaderValues(tt.input) + if got != tt.expected { + t.Errorf("SanitizeAccessControlRequestHeaderValues(%q) = %q, want %q", + tt.input, got, tt.expected) + } + }) + } +}