Skip to content

Commit 50eb875

Browse files
committed
[feature] Use http.Request.Context() in >= Go 1.7.
2 parents 16dc2f5 + 5b56d12 commit 50eb875

File tree

5 files changed

+58
-10
lines changed

5 files changed

+58
-10
lines changed

context.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// +build go1.7
2+
3+
package csrf
4+
5+
import (
6+
"context"
7+
"net/http"
8+
9+
"github.com/pkg/errors"
10+
)
11+
12+
func contextGet(r *http.Request, key string) (interface{}, error) {
13+
val := r.Context().Value(key)
14+
if val == nil {
15+
return nil, errors.Errorf("no value exists in the context for key %q", key)
16+
}
17+
18+
return val, nil
19+
}
20+
21+
func contextSave(r *http.Request, key string, val interface{}) *http.Request {
22+
ctx := r.Context()
23+
ctx = context.WithValue(ctx, key, val)
24+
return r.WithContext(ctx)
25+
}

context_legacy.go

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// +build !go1.7
2+
3+
package csrf
4+
5+
import (
6+
"net/http"
7+
8+
"github.com/gorilla/context"
9+
10+
"github.com/pkg/errors"
11+
)
12+
13+
func contextGet(r *http.Request, key string) (interface{}, error) {
14+
if val, ok := context.GetOk(r, key); ok {
15+
return val, nil
16+
}
17+
18+
return nil, errors.Errorf("no value exists in the context for key %q", key)
19+
}
20+
21+
func contextSave(r *http.Request, key string, val interface{}) *http.Request {
22+
context.Set(r, key, val)
23+
return r
24+
}

csrf.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ func Protect(authKey []byte, opts ...Option) func(http.Handler) http.Handler {
174174
// Implements http.Handler for the csrf type.
175175
func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
176176
// Skip the check if directed to. This should always be a bool.
177-
if val, ok := context.GetOk(r, skipCheckKey); ok {
177+
if val, err := contextGet(r, skipCheckKey); err == nil {
178178
if skip, ok := val.(bool); ok {
179179
if skip {
180180
cs.h.ServeHTTP(w, r)
@@ -209,9 +209,9 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
209209
}
210210

211211
// Save the masked token to the request context
212-
context.Set(r, tokenKey, mask(realToken, r))
212+
r = contextSave(r, tokenKey, mask(realToken, r))
213213
// Save the field name to the request context
214-
context.Set(r, formKey, cs.opts.FieldName)
214+
r = contextSave(r, formKey, cs.opts.FieldName)
215215

216216
// HTTP methods not defined as idempotent ("safe") under RFC7231 require
217217
// inspection.

helpers.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import (
1616
// a JSON response body. An empty token will be returned if the middleware
1717
// has not been applied (which will fail subsequent validation).
1818
func Token(r *http.Request) string {
19-
if val, ok := context.GetOk(r, tokenKey); ok {
19+
if val, err := contextGet(r, tokenKey); err == nil {
2020
if maskedToken, ok := val.(string); ok {
2121
return maskedToken
2222
}
@@ -29,7 +29,7 @@ func Token(r *http.Request) string {
2929
// This is useful when you want to log the cause of the error or report it to
3030
// client.
3131
func FailureReason(r *http.Request) error {
32-
if val, ok := context.GetOk(r, errorKey); ok {
32+
if val, err := contextGet(r, errorKey); err == nil {
3333
if err, ok := val.(error); ok {
3434
return err
3535
}
@@ -44,8 +44,8 @@ func FailureReason(r *http.Request) error {
4444
// Note: You should not set this without otherwise securing the request from
4545
// CSRF attacks. The primary use-case for this function is to turn off CSRF
4646
// checks for non-browser clients using authorization tokens against your API.
47-
func UnsafeSkipCheck(r *http.Request) {
48-
context.Set(r, skipCheckKey, true)
47+
func UnsafeSkipCheck(r *http.Request) *http.Request {
48+
return contextSave(r, skipCheckKey, true)
4949
}
5050

5151
// TemplateField is a template helper for html/template that provides an <input> field
@@ -60,8 +60,7 @@ func UnsafeSkipCheck(r *http.Request) {
6060
// <input type="hidden" name="gorilla.csrf.Token" value="<token>">
6161
//
6262
func TemplateField(r *http.Request) template.HTML {
63-
name, ok := context.GetOk(r, formKey)
64-
if ok {
63+
if name, err := contextGet(r, formKey); err == nil {
6564
fragment := fmt.Sprintf(`<input type="hidden" name="%s" value="%s">`,
6665
name, Token(r))
6766

helpers_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ func TestUnsafeSkipCSRFCheck(t *testing.T) {
270270
s := http.NewServeMux()
271271
skipCheck := func(h http.Handler) http.Handler {
272272
fn := func(w http.ResponseWriter, r *http.Request) {
273-
UnsafeSkipCheck(r)
273+
r = UnsafeSkipCheck(r)
274274
h.ServeHTTP(w, r)
275275
}
276276

0 commit comments

Comments
 (0)