From 4af35c1c52d68d7da67d28a3e181b990efd714b5 Mon Sep 17 00:00:00 2001 From: dorsha Date: Wed, 4 Jun 2025 14:41:59 +0300 Subject: [PATCH 1/2] Better support for trusted origins --- .gitignore | 2 ++ csrf.go | 13 ++++++++++--- options.go | 2 +- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 84039fe..9b691cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ coverage.coverprofile +.vscode/settings.json +.history diff --git a/csrf.go b/csrf.go index 5dda254..b9a2916 100644 --- a/csrf.go +++ b/csrf.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "slices" + "strings" "github.com/gorilla/securecookie" ) @@ -105,7 +106,7 @@ type options struct { FieldName string ErrorHandler http.Handler CookieName string - TrustedOrigins []string + TrustedOrigins string } // Protect is HTTP middleware that provides Cross-Site Request Forgery @@ -276,6 +277,8 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { requestURL.Host = r.Host } + trustedOrigins := strings.Split(cs.opts.TrustedOrigins, ",") + // if we have an Origin header, check it against our allowlist origin := r.Header.Get("Origin") if origin != "" { @@ -285,7 +288,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { cs.opts.ErrorHandler.ServeHTTP(w, r) return } - if !sameOrigin(&requestURL, parsedOrigin) && !slices.Contains(cs.opts.TrustedOrigins, parsedOrigin.Host) { + if !sameOrigin(&requestURL, parsedOrigin) && !slices.ContainsFunc(trustedOrigins, func(trustedOrigin string) bool { + return trustedOrigin == "*" || strings.HasSuffix(parsedOrigin.Host, trustedOrigin) + }) { r = envError(r, ErrBadOrigin) cs.opts.ErrorHandler.ServeHTTP(w, r) return @@ -318,7 +323,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // If the request is being served via TLS and the Referer is not the // same origin, check the domain against our allowlist. We only // check when we have host information from the referer. - if referer.Host != "" && referer.Host != r.Host && !slices.Contains(cs.opts.TrustedOrigins, referer.Host) { + if referer.Host != "" && referer.Host != r.Host && !slices.ContainsFunc(trustedOrigins, func(trustedOrigin string) bool { + return trustedOrigin == "*" || strings.HasSuffix(referer.Host, trustedOrigin) + }) { r = envError(r, ErrBadReferer) cs.opts.ErrorHandler.ServeHTTP(w, r) return diff --git a/options.go b/options.go index c61d301..cfdb800 100644 --- a/options.go +++ b/options.go @@ -125,7 +125,7 @@ func CookieName(name string) Option { // from a different domain than the API server - to correctly pass a CSRF check. // // You should only provide origins you own or have full control over. -func TrustedOrigins(origins []string) Option { +func TrustedOrigins(origins string) Option { return func(cs *csrf) { cs.opts.TrustedOrigins = origins } From 05a1dc6a57f70cffdac6366932392033933bfd5d Mon Sep 17 00:00:00 2001 From: dorsha Date: Wed, 4 Jun 2025 14:43:37 +0300 Subject: [PATCH 2/2] Better support for trusted origins --- csrf.go | 8 +++----- options.go | 2 +- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/csrf.go b/csrf.go index b9a2916..c06eb9d 100644 --- a/csrf.go +++ b/csrf.go @@ -106,7 +106,7 @@ type options struct { FieldName string ErrorHandler http.Handler CookieName string - TrustedOrigins string + TrustedOrigins []string } // Protect is HTTP middleware that provides Cross-Site Request Forgery @@ -277,8 +277,6 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { requestURL.Host = r.Host } - trustedOrigins := strings.Split(cs.opts.TrustedOrigins, ",") - // if we have an Origin header, check it against our allowlist origin := r.Header.Get("Origin") if origin != "" { @@ -288,7 +286,7 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { cs.opts.ErrorHandler.ServeHTTP(w, r) return } - if !sameOrigin(&requestURL, parsedOrigin) && !slices.ContainsFunc(trustedOrigins, func(trustedOrigin string) bool { + if !sameOrigin(&requestURL, parsedOrigin) && !slices.ContainsFunc(cs.opts.TrustedOrigins, func(trustedOrigin string) bool { return trustedOrigin == "*" || strings.HasSuffix(parsedOrigin.Host, trustedOrigin) }) { r = envError(r, ErrBadOrigin) @@ -323,7 +321,7 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // If the request is being served via TLS and the Referer is not the // same origin, check the domain against our allowlist. We only // check when we have host information from the referer. - if referer.Host != "" && referer.Host != r.Host && !slices.ContainsFunc(trustedOrigins, func(trustedOrigin string) bool { + if referer.Host != "" && referer.Host != r.Host && !slices.ContainsFunc(cs.opts.TrustedOrigins, func(trustedOrigin string) bool { return trustedOrigin == "*" || strings.HasSuffix(referer.Host, trustedOrigin) }) { r = envError(r, ErrBadReferer) diff --git a/options.go b/options.go index cfdb800..c61d301 100644 --- a/options.go +++ b/options.go @@ -125,7 +125,7 @@ func CookieName(name string) Option { // from a different domain than the API server - to correctly pass a CSRF check. // // You should only provide origins you own or have full control over. -func TrustedOrigins(origins string) Option { +func TrustedOrigins(origins []string) Option { return func(cs *csrf) { cs.opts.TrustedOrigins = origins }