Skip to content

Commit e711d01

Browse files
committed
add context transformation logic
1 parent fd68fe3 commit e711d01

File tree

4 files changed

+318
-0
lines changed

4 files changed

+318
-0
lines changed

context/context.go

+84
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package context
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"strings"
7+
8+
"github.com/utilitywarehouse/castle-go"
9+
http_internal "github.com/utilitywarehouse/castle-go/http"
10+
)
11+
12+
type contextKey string
13+
14+
func (c contextKey) String() string {
15+
return "castle context key " + string(c)
16+
}
17+
18+
var castleCtxKey = contextKey("castle_context")
19+
20+
// ToCtxFromRequest adds the token and other request information (i.e. castle context) to the context.
21+
func ToCtxFromRequest(ctx context.Context, r *http.Request) context.Context {
22+
castleCtx := castle.Context{
23+
RequestToken: func() string {
24+
// grab the token from header if it exists
25+
if tkn := tokenFromHTTPHeader(r.Header); tkn != "" {
26+
return tkn
27+
}
28+
29+
// otherwise, try grabbing it from form
30+
return tokenFromHTTPForm(r)
31+
}(),
32+
IP: http_internal.IPFromRequest(r),
33+
Headers: FilterHeaders(r.Header), // pass in as much context as possible
34+
}
35+
return context.WithValue(ctx, castleCtxKey, castleCtx)
36+
}
37+
38+
func FromCtx(ctx context.Context) *castle.Context {
39+
castleCtx, ok := ctx.Value(castleCtxKey).(castle.Context)
40+
if ok {
41+
return &castleCtx
42+
}
43+
return nil
44+
}
45+
46+
func tokenFromHTTPHeader(header http.Header) string {
47+
// recommended header name
48+
if t := header.Get("X-Castle-Request-Token"); t != "" {
49+
return t
50+
}
51+
// header name used in the frontends
52+
if t := header.Get("Castle-Token"); t != "" {
53+
return t
54+
}
55+
return ""
56+
}
57+
58+
func tokenFromHTTPForm(r *http.Request) string {
59+
// ParseForm is idempotent, so it's safe to call from anywhere
60+
if err := r.ParseForm(); err != nil {
61+
return ""
62+
}
63+
64+
return r.Form.Get("castle_request_token")
65+
}
66+
67+
func FilterHeaders(hs http.Header) map[string]string {
68+
castleHeaders := make(map[string]string)
69+
for key, value := range hs {
70+
// Ensure cookies or authorization are never sent along.
71+
// Everything else is fair game.
72+
if _, ok := disallowedHeaders[strings.ToLower(key)]; ok {
73+
continue
74+
}
75+
// View: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
76+
castleHeaders[key] = strings.Join(value, ", ")
77+
}
78+
return castleHeaders
79+
}
80+
81+
var disallowedHeaders = map[string]struct{}{
82+
"cookie": {},
83+
"authorization": {},
84+
}

context/context_test.go

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package context
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/stretchr/testify/assert"
10+
"github.com/utilitywarehouse/castle-go"
11+
)
12+
13+
func TestToCtxFromRequest(t *testing.T) {
14+
tests := map[string]struct {
15+
input http.Request
16+
expected castle.Context
17+
}{
18+
"castle token on header": {
19+
input: func() http.Request {
20+
req := httptest.NewRequest(http.MethodPost, "http://example.com", nil)
21+
req.Header.Set("X-Castle-Request-Token", "foo")
22+
req.RemoteAddr = "2.2.2.2"
23+
return *req
24+
}(),
25+
expected: castle.Context{
26+
IP: "2.2.2.2",
27+
Headers: map[string]string{"X-Castle-Request-Token": "foo"},
28+
RequestToken: "foo",
29+
},
30+
},
31+
"castle token in form": {
32+
input: func() http.Request {
33+
req := httptest.NewRequest(http.MethodPost, "http://example.com/bar?castle_request_token=bar", nil)
34+
req.RemoteAddr = "2.2.2.2"
35+
return *req
36+
}(),
37+
expected: castle.Context{
38+
IP: "2.2.2.2",
39+
Headers: map[string]string{},
40+
RequestToken: "bar",
41+
},
42+
},
43+
"no castle token": {
44+
input: func() http.Request {
45+
req := http.Request{}
46+
req.RemoteAddr = "2.2.2.2"
47+
48+
return req
49+
}(),
50+
expected: castle.Context{
51+
IP: "2.2.2.2",
52+
Headers: map[string]string{},
53+
RequestToken: "",
54+
},
55+
},
56+
}
57+
for name, test := range tests {
58+
test := test
59+
60+
t.Run(name, func(t *testing.T) {
61+
ctx := context.Background()
62+
63+
gotCtx := ToCtxFromRequest(ctx, &test.input)
64+
got := FromCtx(gotCtx)
65+
assert.Equal(t, test.expected, *got)
66+
})
67+
}
68+
}

http/ip.go

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
2+
package http
3+
4+
import (
5+
"errors"
6+
"fmt"
7+
"net"
8+
"net/http"
9+
"strings"
10+
)
11+
12+
var cidrs []*net.IPNet
13+
14+
func init() {
15+
maxCidrBlocks := []string{
16+
"127.0.0.1/8", // localhost
17+
"10.0.0.0/8", // 24-bit block
18+
"172.16.0.0/12", // 20-bit block
19+
"192.168.0.0/16", // 16-bit block
20+
"169.254.0.0/16", // link local address
21+
"::1/128", // localhost IPv6
22+
"fc00::/7", // unique local address IPv6
23+
"fe80::/10", // link local address IPv6
24+
}
25+
26+
cidrs = make([]*net.IPNet, len(maxCidrBlocks))
27+
for i, maxCidrBlock := range maxCidrBlocks {
28+
_, cidr, err := net.ParseCIDR(maxCidrBlock)
29+
if err != nil {
30+
panic(fmt.Sprintf("failed to parse CIDR block %q: %v", maxCidrBlock, err))
31+
}
32+
cidrs[i] = cidr
33+
}
34+
}
35+
36+
// IPFromRequest return client's real public IP address from http request headers.
37+
func IPFromRequest(r *http.Request) string {
38+
// If we have it, return this first.
39+
//
40+
// https://developers.cloudflare.com/fundamentals/get-started/reference/http-request-headers/#cf-connecting-ip
41+
if ip := r.Header.Get("Cf-Connecting-Ip"); ip != "" {
42+
return ip
43+
}
44+
45+
// If we have it, try to return the first global address in X-Forwarded-For
46+
for _, ip := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
47+
ip = strings.TrimSpace(ip)
48+
isPrivate, err := isPrivateAddress(ip)
49+
if !isPrivate && err == nil {
50+
return ip
51+
}
52+
}
53+
54+
// Check X-Real-Ip header next
55+
if ip := r.Header.Get("X-Real-Ip"); ip != "" {
56+
return ip
57+
}
58+
59+
// If all else fails, return the remote address
60+
//
61+
// If there are colon in remote address, remove the port number
62+
// otherwise, return remote address as is
63+
var ip string
64+
if strings.ContainsRune(r.RemoteAddr, ':') {
65+
ip, _, _ = net.SplitHostPort(r.RemoteAddr) //nolint:errcheck
66+
} else {
67+
ip = r.RemoteAddr
68+
}
69+
return ip
70+
}
71+
72+
// isPrivateAddress works by checking if the address is under private CIDR blocks.
73+
// List of private CIDR blocks can be seen on :
74+
//
75+
// https://en.wikipedia.org/wiki/Private_network
76+
//
77+
// https://en.wikipedia.org/wiki/Link-local_address
78+
func isPrivateAddress(address string) (bool, error) {
79+
ipAddress := net.ParseIP(address)
80+
if ipAddress == nil {
81+
return false, errors.New("address is not valid")
82+
}
83+
84+
for i := range cidrs {
85+
if cidrs[i].Contains(ipAddress) {
86+
return true, nil
87+
}
88+
}
89+
90+
return false, nil
91+
}

http/ip_test.go

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package http_test
2+
3+
import (
4+
"net/http"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
9+
http_internal "github.com/utilitywarehouse/castle-go/http"
10+
)
11+
12+
func TestIPFromRequest(t *testing.T) {
13+
tests := map[string]struct {
14+
input *http.Request
15+
expected string
16+
}{
17+
"empty": {
18+
input: httpRequest(map[string]string{
19+
"X-Real-Ip": "",
20+
"X-Forwarded-For": "",
21+
"Cf-Connecting-Ip": "",
22+
}, ""),
23+
expected: "",
24+
},
25+
"cf-connecting-ip": {
26+
input: httpRequest(map[string]string{
27+
"X-Real-Ip": "foo",
28+
"X-Forwarded-For": "bar",
29+
"Cf-Connecting-Ip": "cf-connecting-ip",
30+
}, "foobar"),
31+
expected: "cf-connecting-ip",
32+
},
33+
"x-forwarded-for": {
34+
input: httpRequest(map[string]string{
35+
"X-Real-Ip": "foo",
36+
"X-Forwarded-For": "127.0.0.1, 109.14.23.2",
37+
"Cf-Connecting-Ip": "",
38+
}, "foobar"),
39+
expected: "109.14.23.2",
40+
},
41+
"x-real-ip": {
42+
input: httpRequest(map[string]string{
43+
"X-Real-Ip": "x-real-ip",
44+
"X-Forwarded-For": "",
45+
"Cf-Connecting-Ip": "",
46+
}, "foobar"),
47+
expected: "x-real-ip",
48+
},
49+
"remote-addr": {
50+
input: httpRequest(map[string]string{
51+
"X-Real-Ip": "",
52+
"X-Forwarded-For": "",
53+
"Cf-Connecting-Ip": "",
54+
}, "remote-addr:8080"),
55+
expected: "remote-addr",
56+
},
57+
}
58+
for name, test := range tests {
59+
t.Run(name, func(t *testing.T) {
60+
got := http_internal.IPFromRequest(test.input)
61+
assert.Equal(t, test.expected, got)
62+
})
63+
}
64+
}
65+
66+
func httpRequest(headers map[string]string, remoteAddr string) *http.Request {
67+
r := &http.Request{
68+
RemoteAddr: remoteAddr,
69+
Header: make(http.Header),
70+
}
71+
for k, v := range headers {
72+
r.Header.Set(k, v)
73+
}
74+
return r
75+
}

0 commit comments

Comments
 (0)