diff --git a/.gitignore b/.gitignore index 84039fe..9b691cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ coverage.coverprofile +.vscode/settings.json +.history diff --git a/csrf.go b/csrf.go index 5dda254..c06eb9d 100644 --- a/csrf.go +++ b/csrf.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "slices" + "strings" "github.com/gorilla/securecookie" ) @@ -285,7 +286,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { cs.opts.ErrorHandler.ServeHTTP(w, r) return } - if !sameOrigin(&requestURL, parsedOrigin) && !slices.Contains(cs.opts.TrustedOrigins, parsedOrigin.Host) { + if !sameOrigin(&requestURL, parsedOrigin) && !slices.ContainsFunc(cs.opts.TrustedOrigins, func(trustedOrigin string) bool { + return trustedOrigin == "*" || strings.HasSuffix(parsedOrigin.Host, trustedOrigin) + }) { r = envError(r, ErrBadOrigin) cs.opts.ErrorHandler.ServeHTTP(w, r) return @@ -318,7 +321,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) { // If the request is being served via TLS and the Referer is not the // same origin, check the domain against our allowlist. We only // check when we have host information from the referer. - if referer.Host != "" && referer.Host != r.Host && !slices.Contains(cs.opts.TrustedOrigins, referer.Host) { + if referer.Host != "" && referer.Host != r.Host && !slices.ContainsFunc(cs.opts.TrustedOrigins, func(trustedOrigin string) bool { + return trustedOrigin == "*" || strings.HasSuffix(referer.Host, trustedOrigin) + }) { r = envError(r, ErrBadReferer) cs.opts.ErrorHandler.ServeHTTP(w, r) return