Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Context tools #19

Merged
merged 2 commits into from
Apr 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions context/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package context

import (
"context"
"net/http"
"strings"

"github.com/utilitywarehouse/castle-go"
http_internal "github.com/utilitywarehouse/castle-go/http"
)

type contextKey string

func (c contextKey) String() string {
return "castle context key " + string(c)
}

var castleCtxKey = contextKey("castle_context")

// ToCtxFromRequest adds the token and other request information (i.e. castle context) to the context.
func ToCtxFromRequest(ctx context.Context, r *http.Request) context.Context {
castleCtx := castle.Context{
RequestToken: func() string {
// grab the token from header if it exists
if tkn := tokenFromHTTPHeader(r.Header); tkn != "" {
return tkn
}

// otherwise, try grabbing it from form
return tokenFromHTTPForm(r)
}(),
IP: http_internal.IPFromRequest(r),
Headers: FilterHeaders(r.Header), // pass in as much context as possible
}
return context.WithValue(ctx, castleCtxKey, castleCtx)
}

func FromCtx(ctx context.Context) *castle.Context {
castleCtx, ok := ctx.Value(castleCtxKey).(castle.Context)
if ok {
return &castleCtx
}
return nil
}

func tokenFromHTTPHeader(header http.Header) string {
// recommended header name
if t := header.Get("X-Castle-Request-Token"); t != "" {
return t
}
// header name used in the frontends
if t := header.Get("Castle-Token"); t != "" {
return t
}
return ""
}

func tokenFromHTTPForm(r *http.Request) string {
// ParseForm is idempotent, so it's safe to call from anywhere
if err := r.ParseForm(); err != nil {
return ""
}

return r.Form.Get("castle_request_token")
}

func FilterHeaders(hs http.Header) map[string]string {
castleHeaders := make(map[string]string)
for key, value := range hs {
// Ensure cookies or authorization are never sent along.
// Everything else is fair game.
if _, ok := disallowedHeaders[strings.ToLower(key)]; ok {
continue
}
// View: https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html
castleHeaders[key] = strings.Join(value, ", ")
}
return castleHeaders
}

var disallowedHeaders = map[string]struct{}{
"cookie": {},
"authorization": {},
}
69 changes: 69 additions & 0 deletions context/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package context

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"

"github.com/utilitywarehouse/castle-go"
)

func TestToCtxFromRequest(t *testing.T) {
tests := map[string]struct {
input http.Request
expected castle.Context
}{
"castle token on header": {
input: func() http.Request {
req := httptest.NewRequest(http.MethodPost, "http://example.com", nil)
req.Header.Set("X-Castle-Request-Token", "foo")
req.RemoteAddr = "2.2.2.2"
return *req
}(),
expected: castle.Context{
IP: "2.2.2.2",
Headers: map[string]string{"X-Castle-Request-Token": "foo"},
RequestToken: "foo",
},
},
"castle token in form": {
input: func() http.Request {
req := httptest.NewRequest(http.MethodPost, "http://example.com/bar?castle_request_token=bar", nil)
req.RemoteAddr = "2.2.2.2"
return *req
}(),
expected: castle.Context{
IP: "2.2.2.2",
Headers: map[string]string{},
RequestToken: "bar",
},
},
"no castle token": {
input: func() http.Request {
req := http.Request{}
req.RemoteAddr = "2.2.2.2"

return req
}(),
expected: castle.Context{
IP: "2.2.2.2",
Headers: map[string]string{},
RequestToken: "",
},
},
}
for name, test := range tests {
test := test

t.Run(name, func(t *testing.T) {
ctx := context.Background()

gotCtx := ToCtxFromRequest(ctx, &test.input)
got := FromCtx(gotCtx)
assert.Equal(t, test.expected, *got)
})
}
}
90 changes: 90 additions & 0 deletions http/ip.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package http

import (
"errors"
"fmt"
"net"
"net/http"
"strings"
)

var cidrs []*net.IPNet

func init() {
maxCidrBlocks := []string{
"127.0.0.1/8", // localhost
"10.0.0.0/8", // 24-bit block
"172.16.0.0/12", // 20-bit block
"192.168.0.0/16", // 16-bit block
"169.254.0.0/16", // link local address
"::1/128", // localhost IPv6
"fc00::/7", // unique local address IPv6
"fe80::/10", // link local address IPv6
}

cidrs = make([]*net.IPNet, len(maxCidrBlocks))
for i, maxCidrBlock := range maxCidrBlocks {
_, cidr, err := net.ParseCIDR(maxCidrBlock)
if err != nil {
panic(fmt.Sprintf("failed to parse CIDR block %q: %v", maxCidrBlock, err))
}
cidrs[i] = cidr
}
}

// IPFromRequest return client's real public IP address from http request headers.
func IPFromRequest(r *http.Request) string {
// If we have it, return this first.
//
// https://developers.cloudflare.com/fundamentals/get-started/reference/http-request-headers/#cf-connecting-ip
if ip := r.Header.Get("Cf-Connecting-Ip"); ip != "" {
return ip
}

// If we have it, try to return the first global address in X-Forwarded-For
for _, ip := range strings.Split(r.Header.Get("X-Forwarded-For"), ",") {
ip = strings.TrimSpace(ip)
isPrivate, err := isPrivateAddress(ip)
if !isPrivate && err == nil {
return ip
}
}

// Check X-Real-Ip header next
if ip := r.Header.Get("X-Real-Ip"); ip != "" {
return ip
}

// If all else fails, return the remote address
//
// If there are colon in remote address, remove the port number
// otherwise, return remote address as is
var ip string
if strings.ContainsRune(r.RemoteAddr, ':') {
ip, _, _ = net.SplitHostPort(r.RemoteAddr) //nolint:errcheck
} else {
ip = r.RemoteAddr
}
return ip
}

// isPrivateAddress works by checking if the address is under private CIDR blocks.
// List of private CIDR blocks can be seen on :
//
// https://en.wikipedia.org/wiki/Private_network
//
// https://en.wikipedia.org/wiki/Link-local_address
func isPrivateAddress(address string) (bool, error) {
ipAddress := net.ParseIP(address)
if ipAddress == nil {
return false, errors.New("address is not valid")
}

for i := range cidrs {
if cidrs[i].Contains(ipAddress) {
return true, nil
}
}

return false, nil
}
75 changes: 75 additions & 0 deletions http/ip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package http_test

import (
"net/http"
"testing"

"github.com/stretchr/testify/assert"

http_internal "github.com/utilitywarehouse/castle-go/http"
)

func TestIPFromRequest(t *testing.T) {
tests := map[string]struct {
input *http.Request
expected string
}{
"empty": {
input: httpRequest(map[string]string{
"X-Real-Ip": "",
"X-Forwarded-For": "",
"Cf-Connecting-Ip": "",
}, ""),
expected: "",
},
"cf-connecting-ip": {
input: httpRequest(map[string]string{
"X-Real-Ip": "foo",
"X-Forwarded-For": "bar",
"Cf-Connecting-Ip": "cf-connecting-ip",
}, "foobar"),
expected: "cf-connecting-ip",
},
"x-forwarded-for": {
input: httpRequest(map[string]string{
"X-Real-Ip": "foo",
"X-Forwarded-For": "127.0.0.1, 109.14.23.2",
"Cf-Connecting-Ip": "",
}, "foobar"),
expected: "109.14.23.2",
},
"x-real-ip": {
input: httpRequest(map[string]string{
"X-Real-Ip": "x-real-ip",
"X-Forwarded-For": "",
"Cf-Connecting-Ip": "",
}, "foobar"),
expected: "x-real-ip",
},
"remote-addr": {
input: httpRequest(map[string]string{
"X-Real-Ip": "",
"X-Forwarded-For": "",
"Cf-Connecting-Ip": "",
}, "remote-addr:8080"),
expected: "remote-addr",
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
got := http_internal.IPFromRequest(test.input)
assert.Equal(t, test.expected, got)
})
}
}

func httpRequest(headers map[string]string, remoteAddr string) *http.Request {
r := &http.Request{
RemoteAddr: remoteAddr,
Header: make(http.Header),
}
for k, v := range headers {
r.Header.Set(k, v)
}
return r
}
Loading