Skip to content

Commit 4af35c1

Browse files
committed
Better support for trusted origins
1 parent 9dd6af1 commit 4af35c1

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
coverage.coverprofile
2+
.vscode/settings.json
3+
.history

csrf.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"net/http"
88
"net/url"
99
"slices"
10+
"strings"
1011

1112
"github.com/gorilla/securecookie"
1213
)
@@ -105,7 +106,7 @@ type options struct {
105106
FieldName string
106107
ErrorHandler http.Handler
107108
CookieName string
108-
TrustedOrigins []string
109+
TrustedOrigins string
109110
}
110111

111112
// Protect is HTTP middleware that provides Cross-Site Request Forgery
@@ -276,6 +277,8 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
276277
requestURL.Host = r.Host
277278
}
278279

280+
trustedOrigins := strings.Split(cs.opts.TrustedOrigins, ",")
281+
279282
// if we have an Origin header, check it against our allowlist
280283
origin := r.Header.Get("Origin")
281284
if origin != "" {
@@ -285,7 +288,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
285288
cs.opts.ErrorHandler.ServeHTTP(w, r)
286289
return
287290
}
288-
if !sameOrigin(&requestURL, parsedOrigin) && !slices.Contains(cs.opts.TrustedOrigins, parsedOrigin.Host) {
291+
if !sameOrigin(&requestURL, parsedOrigin) && !slices.ContainsFunc(trustedOrigins, func(trustedOrigin string) bool {
292+
return trustedOrigin == "*" || strings.HasSuffix(parsedOrigin.Host, trustedOrigin)
293+
}) {
289294
r = envError(r, ErrBadOrigin)
290295
cs.opts.ErrorHandler.ServeHTTP(w, r)
291296
return
@@ -318,7 +323,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
318323
// If the request is being served via TLS and the Referer is not the
319324
// same origin, check the domain against our allowlist. We only
320325
// check when we have host information from the referer.
321-
if referer.Host != "" && referer.Host != r.Host && !slices.Contains(cs.opts.TrustedOrigins, referer.Host) {
326+
if referer.Host != "" && referer.Host != r.Host && !slices.ContainsFunc(trustedOrigins, func(trustedOrigin string) bool {
327+
return trustedOrigin == "*" || strings.HasSuffix(referer.Host, trustedOrigin)
328+
}) {
322329
r = envError(r, ErrBadReferer)
323330
cs.opts.ErrorHandler.ServeHTTP(w, r)
324331
return

options.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func CookieName(name string) Option {
125125
// from a different domain than the API server - to correctly pass a CSRF check.
126126
//
127127
// You should only provide origins you own or have full control over.
128-
func TrustedOrigins(origins []string) Option {
128+
func TrustedOrigins(origins string) Option {
129129
return func(cs *csrf) {
130130
cs.opts.TrustedOrigins = origins
131131
}

0 commit comments

Comments
 (0)