diff --git a/context/context.go b/context/context.go new file mode 100644 index 0000000..d97dfbc --- /dev/null +++ b/context/context.go @@ -0,0 +1,84 @@ +package context + +import ( + "context" + "net/http" + "strings" + + "github.com/utilitywarehouse/castle-go" + http_internal "github.com/utilitywarehouse/castle-go/http" +) + +type contextKey string + +func (c contextKey) String() string { + return "castle context key " + string(c) +} + +var castleCtxKey = contextKey("castle_context") + +// ToCtxFromRequest adds the token and other request information (i.e. castle context) to the context. +func ToCtxFromRequest(ctx context.Context, r *http.Request) context.Context { + castleCtx := castle.Context{ + RequestToken: func() string { + // grab the token from header if it exists + if tkn := tokenFromHTTPHeader(r.Header); tkn != "" { + return tkn + } + + // otherwise, try grabbing it from form + return tokenFromHTTPForm(r) + }(), + IP: http_internal.IPFromRequest(r), + Headers: FilterHeaders(r.Header), // pass in as much context as possible + } + return context.WithValue(ctx, castleCtxKey, castleCtx) +} + +func FromCtx(ctx context.Context) *castle.Context { + castleCtx, ok := ctx.Value(castleCtxKey).(castle.Context) + if ok { + return &castleCtx + } + return nil +} + +func tokenFromHTTPHeader(header http.Header) string { + // recommended header name + if t := header.Get("X-Castle-Request-Token"); t != "" { + return t + } + // header name used in the frontends + if t := header.Get("Castle-Token"); t != "" { + return t + } + return "" +} + +func tokenFromHTTPForm(r *http.Request) string { + // ParseForm is idempotent, so it's safe to call from anywhere + if err := r.ParseForm(); err != nil { + return "" + } + + return r.Form.Get("castle_request_token") +} + +func FilterHeaders(hs http.Header) map[string]string { + castleHeaders := make(map[string]string) + for key, value := range hs { + // Ensure cookies or authorization are never sent along. + // Everything else is fair game. + if _, ok := disallowedHeaders[strings.ToLower(key)]; ok { + continue + } + // View: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html + castleHeaders[key] = strings.Join(value, ", ") + } + return castleHeaders +} + +var disallowedHeaders = map[string]struct{}{ + "cookie": {}, + "authorization": {}, +} diff --git a/context/context_test.go b/context/context_test.go new file mode 100644 index 0000000..8ba817c --- /dev/null +++ b/context/context_test.go @@ -0,0 +1,69 @@ +package context + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/utilitywarehouse/castle-go" +) + +func TestToCtxFromRequest(t *testing.T) { + tests := map[string]struct { + input http.Request + expected castle.Context + }{ + "castle token on header": { + input: func() http.Request { + req := httptest.NewRequest(http.MethodPost, "http://example.com", nil) + req.Header.Set("X-Castle-Request-Token", "foo") + req.RemoteAddr = "2.2.2.2" + return *req + }(), + expected: castle.Context{ + IP: "2.2.2.2", + Headers: map[string]string{"X-Castle-Request-Token": "foo"}, + RequestToken: "foo", + }, + }, + "castle token in form": { + input: func() http.Request { + req := httptest.NewRequest(http.MethodPost, "http://example.com/bar?castle_request_token=bar", nil) + req.RemoteAddr = "2.2.2.2" + return *req + }(), + expected: castle.Context{ + IP: "2.2.2.2", + Headers: map[string]string{}, + RequestToken: "bar", + }, + }, + "no castle token": { + input: func() http.Request { + req := http.Request{} + req.RemoteAddr = "2.2.2.2" + + return req + }(), + expected: castle.Context{ + IP: "2.2.2.2", + Headers: map[string]string{}, + RequestToken: "", + }, + }, + } + for name, test := range tests { + test := test + + t.Run(name, func(t *testing.T) { + ctx := context.Background() + + gotCtx := ToCtxFromRequest(ctx, &test.input) + got := FromCtx(gotCtx) + assert.Equal(t, test.expected, *got) + }) + } +} diff --git a/http/ip.go b/http/ip.go new file mode 100644 index 0000000..f3092c8 --- /dev/null +++ b/http/ip.go @@ -0,0 +1,90 @@ +package http + +import ( + "errors" + "fmt" + "net" + "net/http" + "strings" +) + +var cidrs []*net.IPNet + +func init() { + maxCidrBlocks := []string{ + "127.0.0.1/8", // localhost + "10.0.0.0/8", // 24-bit block + "172.16.0.0/12", // 20-bit block + "192.168.0.0/16", // 16-bit block + "169.254.0.0/16", // link local address + "::1/128", // localhost IPv6 + "fc00::/7", // unique local address IPv6 + "fe80::/10", // link local address IPv6 + } + + cidrs = make([]*net.IPNet, len(maxCidrBlocks)) + for i, maxCidrBlock := range maxCidrBlocks { + _, cidr, err := net.ParseCIDR(maxCidrBlock) + if err != nil { + panic(fmt.Sprintf("failed to parse CIDR block %q: %v", maxCidrBlock, err)) + } + cidrs[i] = cidr + } +} + +// IPFromRequest return client's real public IP address from http request headers. +func IPFromRequest(r *http.Request) string { + // If we have it, return this first. + // + // https://developers.cloudflare.com/fundamentals/get-started/reference/http-request-headers/#cf-connecting-ip + if ip := r.Header.Get("Cf-Connecting-Ip"); ip != "" { + return ip + } + + // If we have it, try to return the first global address in X-Forwarded-For + for _, ip := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") { + ip = strings.TrimSpace(ip) + isPrivate, err := isPrivateAddress(ip) + if !isPrivate && err == nil { + return ip + } + } + + // Check X-Real-Ip header next + if ip := r.Header.Get("X-Real-Ip"); ip != "" { + return ip + } + + // If all else fails, return the remote address + // + // If there are colon in remote address, remove the port number + // otherwise, return remote address as is + var ip string + if strings.ContainsRune(r.RemoteAddr, ':') { + ip, _, _ = net.SplitHostPort(r.RemoteAddr) //nolint:errcheck + } else { + ip = r.RemoteAddr + } + return ip +} + +// isPrivateAddress works by checking if the address is under private CIDR blocks. +// List of private CIDR blocks can be seen on : +// +// https://en.wikipedia.org/wiki/Private_network +// +// https://en.wikipedia.org/wiki/Link-local_address +func isPrivateAddress(address string) (bool, error) { + ipAddress := net.ParseIP(address) + if ipAddress == nil { + return false, errors.New("address is not valid") + } + + for i := range cidrs { + if cidrs[i].Contains(ipAddress) { + return true, nil + } + } + + return false, nil +} diff --git a/http/ip_test.go b/http/ip_test.go new file mode 100644 index 0000000..bf013a4 --- /dev/null +++ b/http/ip_test.go @@ -0,0 +1,75 @@ +package http_test + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + + http_internal "github.com/utilitywarehouse/castle-go/http" +) + +func TestIPFromRequest(t *testing.T) { + tests := map[string]struct { + input *http.Request + expected string + }{ + "empty": { + input: httpRequest(map[string]string{ + "X-Real-Ip": "", + "X-Forwarded-For": "", + "Cf-Connecting-Ip": "", + }, ""), + expected: "", + }, + "cf-connecting-ip": { + input: httpRequest(map[string]string{ + "X-Real-Ip": "foo", + "X-Forwarded-For": "bar", + "Cf-Connecting-Ip": "cf-connecting-ip", + }, "foobar"), + expected: "cf-connecting-ip", + }, + "x-forwarded-for": { + input: httpRequest(map[string]string{ + "X-Real-Ip": "foo", + "X-Forwarded-For": "127.0.0.1, 109.14.23.2", + "Cf-Connecting-Ip": "", + }, "foobar"), + expected: "109.14.23.2", + }, + "x-real-ip": { + input: httpRequest(map[string]string{ + "X-Real-Ip": "x-real-ip", + "X-Forwarded-For": "", + "Cf-Connecting-Ip": "", + }, "foobar"), + expected: "x-real-ip", + }, + "remote-addr": { + input: httpRequest(map[string]string{ + "X-Real-Ip": "", + "X-Forwarded-For": "", + "Cf-Connecting-Ip": "", + }, "remote-addr:8080"), + expected: "remote-addr", + }, + } + for name, test := range tests { + t.Run(name, func(t *testing.T) { + got := http_internal.IPFromRequest(test.input) + assert.Equal(t, test.expected, got) + }) + } +} + +func httpRequest(headers map[string]string, remoteAddr string) *http.Request { + r := &http.Request{ + RemoteAddr: remoteAddr, + Header: make(http.Header), + } + for k, v := range headers { + r.Header.Set(k, v) + } + return r +}