Skip to content

Commit f1032b5

Browse files
authored
Context tools (#19)
* add context transformation logic
1 parent fd68fe3 commit f1032b5

File tree

4 files changed

+318
-0
lines changed

4 files changed

+318
-0
lines changed

context/context.go

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package context
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"strings"
7+
8+
"github.com/utilitywarehouse/castle-go"
9+
http_internal "github.com/utilitywarehouse/castle-go/http"
10+
)
11+
12+
type contextKey string
13+
14+
func (c contextKey) String() string {
15+
return "castle context key " + string(c)
16+
}
17+
18+
var castleCtxKey = contextKey("castle_context")
19+
20+
// ToCtxFromRequest adds the token and other request information (i.e. castle context) to the context.
21+
func ToCtxFromRequest(ctx context.Context, r *http.Request) context.Context {
22+
castleCtx := castle.Context{
23+
RequestToken: func() string {
24+
// grab the token from header if it exists
25+
if tkn := tokenFromHTTPHeader(r.Header); tkn != "" {
26+
return tkn
27+
}
28+
29+
// otherwise, try grabbing it from form
30+
return tokenFromHTTPForm(r)
31+
}(),
32+
IP: http_internal.IPFromRequest(r),
33+
Headers: FilterHeaders(r.Header), // pass in as much context as possible
34+
}
35+
return context.WithValue(ctx, castleCtxKey, castleCtx)
36+
}
37+
38+
func FromCtx(ctx context.Context) *castle.Context {
39+
castleCtx, ok := ctx.Value(castleCtxKey).(castle.Context)
40+
if ok {
41+
return &castleCtx
42+
}
43+
return nil
44+
}
45+
46+
func tokenFromHTTPHeader(header http.Header) string {
47+
// recommended header name
48+
if t := header.Get("X-Castle-Request-Token"); t != "" {
49+
return t
50+
}
51+
// header name used in the frontends
52+
if t := header.Get("Castle-Token"); t != "" {
53+
return t
54+
}
55+
return ""
56+
}
57+
58+
func tokenFromHTTPForm(r *http.Request) string {
59+
// ParseForm is idempotent, so it's safe to call from anywhere
60+
if err := r.ParseForm(); err != nil {
61+
return ""
62+
}
63+
64+
return r.Form.Get("castle_request_token")
65+
}
66+
67+
func FilterHeaders(hs http.Header) map[string]string {
68+
castleHeaders := make(map[string]string)
69+
for key, value := range hs {
70+
// Ensure cookies or authorization are never sent along.
71+
// Everything else is fair game.
72+
if _, ok := disallowedHeaders[strings.ToLower(key)]; ok {
73+
continue
74+
}
75+
// View: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
76+
castleHeaders[key] = strings.Join(value, ", ")
77+
}
78+
return castleHeaders
79+
}
80+
81+
var disallowedHeaders = map[string]struct{}{
82+
"cookie": {},
83+
"authorization": {},
84+
}

context/context_test.go

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package context
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
11+
"github.com/utilitywarehouse/castle-go"
12+
)
13+
14+
func TestToCtxFromRequest(t *testing.T) {
15+
tests := map[string]struct {
16+
input http.Request
17+
expected castle.Context
18+
}{
19+
"castle token on header": {
20+
input: func() http.Request {
21+
req := httptest.NewRequest(http.MethodPost, "http://example.com", nil)
22+
req.Header.Set("X-Castle-Request-Token", "foo")
23+
req.RemoteAddr = "2.2.2.2"
24+
return *req
25+
}(),
26+
expected: castle.Context{
27+
IP: "2.2.2.2",
28+
Headers: map[string]string{"X-Castle-Request-Token": "foo"},
29+
RequestToken: "foo",
30+
},
31+
},
32+
"castle token in form": {
33+
input: func() http.Request {
34+
req := httptest.NewRequest(http.MethodPost, "http://example.com/bar?castle_request_token=bar", nil)
35+
req.RemoteAddr = "2.2.2.2"
36+
return *req
37+
}(),
38+
expected: castle.Context{
39+
IP: "2.2.2.2",
40+
Headers: map[string]string{},
41+
RequestToken: "bar",
42+
},
43+
},
44+
"no castle token": {
45+
input: func() http.Request {
46+
req := http.Request{}
47+
req.RemoteAddr = "2.2.2.2"
48+
49+
return req
50+
}(),
51+
expected: castle.Context{
52+
IP: "2.2.2.2",
53+
Headers: map[string]string{},
54+
RequestToken: "",
55+
},
56+
},
57+
}
58+
for name, test := range tests {
59+
test := test
60+
61+
t.Run(name, func(t *testing.T) {
62+
ctx := context.Background()
63+
64+
gotCtx := ToCtxFromRequest(ctx, &test.input)
65+
got := FromCtx(gotCtx)
66+
assert.Equal(t, test.expected, *got)
67+
})
68+
}
69+
}

http/ip.go

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
package http
2+
3+
import (
4+
"errors"
5+
"fmt"
6+
"net"
7+
"net/http"
8+
"strings"
9+
)
10+
11+
var cidrs []*net.IPNet
12+
13+
func init() {
14+
maxCidrBlocks := []string{
15+
"127.0.0.1/8", // localhost
16+
"10.0.0.0/8", // 24-bit block
17+
"172.16.0.0/12", // 20-bit block
18+
"192.168.0.0/16", // 16-bit block
19+
"169.254.0.0/16", // link local address
20+
"::1/128", // localhost IPv6
21+
"fc00::/7", // unique local address IPv6
22+
"fe80::/10", // link local address IPv6
23+
}
24+
25+
cidrs = make([]*net.IPNet, len(maxCidrBlocks))
26+
for i, maxCidrBlock := range maxCidrBlocks {
27+
_, cidr, err := net.ParseCIDR(maxCidrBlock)
28+
if err != nil {
29+
panic(fmt.Sprintf("failed to parse CIDR block %q: %v", maxCidrBlock, err))
30+
}
31+
cidrs[i] = cidr
32+
}
33+
}
34+
35+
// IPFromRequest return client's real public IP address from http request headers.
36+
func IPFromRequest(r *http.Request) string {
37+
// If we have it, return this first.
38+
//
39+
// https://developers.cloudflare.com/fundamentals/get-started/reference/http-request-headers/#cf-connecting-ip
40+
if ip := r.Header.Get("Cf-Connecting-Ip"); ip != "" {
41+
return ip
42+
}
43+
44+
// If we have it, try to return the first global address in X-Forwarded-For
45+
for _, ip := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
46+
ip = strings.TrimSpace(ip)
47+
isPrivate, err := isPrivateAddress(ip)
48+
if !isPrivate && err == nil {
49+
return ip
50+
}
51+
}
52+
53+
// Check X-Real-Ip header next
54+
if ip := r.Header.Get("X-Real-Ip"); ip != "" {
55+
return ip
56+
}
57+
58+
// If all else fails, return the remote address
59+
//
60+
// If there are colon in remote address, remove the port number
61+
// otherwise, return remote address as is
62+
var ip string
63+
if strings.ContainsRune(r.RemoteAddr, ':') {
64+
ip, _, _ = net.SplitHostPort(r.RemoteAddr) //nolint:errcheck
65+
} else {
66+
ip = r.RemoteAddr
67+
}
68+
return ip
69+
}
70+
71+
// isPrivateAddress works by checking if the address is under private CIDR blocks.
72+
// List of private CIDR blocks can be seen on :
73+
//
74+
// https://en.wikipedia.org/wiki/Private_network
75+
//
76+
// https://en.wikipedia.org/wiki/Link-local_address
77+
func isPrivateAddress(address string) (bool, error) {
78+
ipAddress := net.ParseIP(address)
79+
if ipAddress == nil {
80+
return false, errors.New("address is not valid")
81+
}
82+
83+
for i := range cidrs {
84+
if cidrs[i].Contains(ipAddress) {
85+
return true, nil
86+
}
87+
}
88+
89+
return false, nil
90+
}

http/ip_test.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package http_test
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
9+
http_internal "github.com/utilitywarehouse/castle-go/http"
10+
)
11+
12+
func TestIPFromRequest(t *testing.T) {
13+
tests := map[string]struct {
14+
input *http.Request
15+
expected string
16+
}{
17+
"empty": {
18+
input: httpRequest(map[string]string{
19+
"X-Real-Ip": "",
20+
"X-Forwarded-For": "",
21+
"Cf-Connecting-Ip": "",
22+
}, ""),
23+
expected: "",
24+
},
25+
"cf-connecting-ip": {
26+
input: httpRequest(map[string]string{
27+
"X-Real-Ip": "foo",
28+
"X-Forwarded-For": "bar",
29+
"Cf-Connecting-Ip": "cf-connecting-ip",
30+
}, "foobar"),
31+
expected: "cf-connecting-ip",
32+
},
33+
"x-forwarded-for": {
34+
input: httpRequest(map[string]string{
35+
"X-Real-Ip": "foo",
36+
"X-Forwarded-For": "127.0.0.1, 109.14.23.2",
37+
"Cf-Connecting-Ip": "",
38+
}, "foobar"),
39+
expected: "109.14.23.2",
40+
},
41+
"x-real-ip": {
42+
input: httpRequest(map[string]string{
43+
"X-Real-Ip": "x-real-ip",
44+
"X-Forwarded-For": "",
45+
"Cf-Connecting-Ip": "",
46+
}, "foobar"),
47+
expected: "x-real-ip",
48+
},
49+
"remote-addr": {
50+
input: httpRequest(map[string]string{
51+
"X-Real-Ip": "",
52+
"X-Forwarded-For": "",
53+
"Cf-Connecting-Ip": "",
54+
}, "remote-addr:8080"),
55+
expected: "remote-addr",
56+
},
57+
}
58+
for name, test := range tests {
59+
t.Run(name, func(t *testing.T) {
60+
got := http_internal.IPFromRequest(test.input)
61+
assert.Equal(t, test.expected, got)
62+
})
63+
}
64+
}
65+
66+
func httpRequest(headers map[string]string, remoteAddr string) *http.Request {
67+
r := &http.Request{
68+
RemoteAddr: remoteAddr,
69+
Header: make(http.Header),
70+
}
71+
for k, v := range headers {
72+
r.Header.Set(k, v)
73+
}
74+
return r
75+
}

0 commit comments

Comments
 (0)