diff --git a/.gitignore b/.gitignore index a3b9b4e0..348636d5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,27 +1,6 @@ -# Compiled Object files, Static and Dynamic libs (Shared Objects) -*.o -*.a -*.so - -# Folders -_obj -_test - -# Architecture specific extensions/prefixes -*.[568vq] -[568vq].out - -*.cgo1.go -*.cgo2.c -_cgo_defun.c -_cgo_gotypes.go -_cgo_export.* - -_testmain.go - +.* *.exe *.test *.prof - dist/* coverage.txt diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md new file mode 100644 index 00000000..27265e31 --- /dev/null +++ b/DEVELOPMENT.md @@ -0,0 +1,113 @@ +# Development + +## Local development + +For interactive local development, use `make run` to build and run go-httpbin +or `make watch` to automatically re-build and re-run go-httpbin on every +change: + + make run + make watch + +By default, the server will listen on `http://127.0.0.1:8080`, but the host, +port, or any other [configuration option][config] may be overridden by +specifying the relevant environment variables: + + make run PORT=9999 + make run PORT=9999 MAX_DURATION=60s + make watch HOST=0.0.0.0 PORT=8888 + +## Testing + +Run `make test` to run unit tests, using `TEST_ARGS` to pass arguments through +to `go test`: + + make test + make test TEST_ARGS="-v -race -run ^TestDelay" + +### Integration tests + +go-httpbin includes its own minimal WebSocket echo server implementation, and +we use the incredibly helpful [Autobahn Testsuite][] to ensure that the +implementation conforms to the spec. + +These tests can be slow to run (~40 seconds on my machine), so they are not run +by default when using `make test`. + +They are run automatically as part of our extended "CI" test suite, which is +run on every pull request: + + make testci + +### WebSocket development + +When working on the WebSocket implementation, it can also be useful to run +those integration tests directly, like so: + + make testautobahn + +Use the `AUTOBAHN_CASES` var to run a specific subset of the Autobahn tests, +which may or may not include wildcards: + + make testautobahn AUTOBAHN_CASES=6.* + make testautobahn AUTOBAHN_CASES=6.5.* + make testautobahn AUTOBAHN_CASES=6.5.4 + + +### Test coverage + +We use [Codecov][] to measure and track test coverage as part of our continuous +integration test suite. While we strive for as much coverage as possible and +the Codecov CI check is configured with fairly strict requirements, 100% test +coverage is not an explicit goal or requirement for all contributions. + +To view test coverage locally, use + + make testcover + +which will run the full suite of unit and integration tests and pop open a web +browser to view coverage results. + + +## Linting and code style + +Run `make lint` to run our suite of linters and formatters, which include +gofmt, [revive][], and [staticcheck][]: + + make lint + + +## Docker images + +To build a docker image locally: + + make image + +To build a docker image an push it to a remote repository: + + make imagepush + +By default, images will be tagged as `mccutchen/go-httpbin:${COMMIT}` with the +current HEAD commit hash. + +Use `VERSION` to override the tag value + + make imagepush VERSION=v1.2.3 + +or `DOCKER_TAG` to override the remote repo and version at once: + + make imagepush DOCKER_TAG=my-org/my-fork:v1.2.3 + +### Automated docker image builds + +When a new release is created, the [Release][] GitHub Actions workflow +automatically builds and pushes new Docker images for both linux/amd64 and +linux/arm64 architectures. + + +[config]: /README.md#configuration +[revive]: https://github.com/mgechev/revive +[staticcheck]: https://staticcheck.dev/ +[Release]: /.github/workflows/release.yaml +[Codecov]: https://app.codecov.io/gh/mccutchen/go-httpbin +[Autobahn Testsuite]: https://github.com/crossbario/autobahn-testsuite diff --git a/Makefile b/Makefile index 033d7655..ad3f9c7e 100644 --- a/Makefile +++ b/Makefile @@ -38,7 +38,7 @@ buildtests: .PHONY: buildtests clean: - rm -rf $(DIST_PATH) $(COVERAGE_PATH) + rm -rf $(DIST_PATH) $(COVERAGE_PATH) .integrationtests .PHONY: clean @@ -53,14 +53,18 @@ test: # based on codecov.io's documentation: # https://github.com/codecov/example-go/blob/b85638743b972bd0bd2af63421fe513c6f968930/README.md testci: build buildexamples - go test $(TEST_ARGS) $(COVERAGE_ARGS) ./... - git diff --exit-code + AUTOBAHN_TESTS=1 go test $(TEST_ARGS) $(COVERAGE_ARGS) ./... .PHONY: testci testcover: testci go tool cover -html=$(COVERAGE_PATH) .PHONY: testcover +# Run the autobahn fuzzingclient test suite +testautobahn: + AUTOBAHN_TESTS=1 AUTOBAHN_OPEN_REPORT=1 go test -v -run ^TestWebSocketServer$$ $(TEST_ARGS) ./... +.PHONY: autobahntests + lint: test -z "$$(gofmt -d -s -e .)" || (echo "Error: gofmt failed"; gofmt -d -s -e . ; exit 1) go vet ./... diff --git a/README.md b/README.md index 9d3d9a57..58f0c815 100644 --- a/README.md +++ b/README.md @@ -169,17 +169,7 @@ public internet, consider tuning it appropriately: ## Development -```bash -# local development -make -make test -make testcover -make run - -# building & pushing docker images -make image -make imagepush -``` +See [DEVELOPMENT.md][]. ## Motivation & prior art @@ -218,3 +208,4 @@ Compared to [ahmetb/go-httpbin][ahmet]: [Observer]: https://pkg.go.dev/github.com/mccutchen/go-httpbin/v2/httpbin#Observer [Production considerations]: #production-considerations [zerolog]: https://github.com/rs/zerolog +[DEVELOPMENT.md]: ./DEVELOPMENT.md diff --git a/httpbin/handlers.go b/httpbin/handlers.go index a122a016..7ecb7bc3 100644 --- a/httpbin/handlers.go +++ b/httpbin/handlers.go @@ -15,6 +15,7 @@ import ( "time" "github.com/mccutchen/go-httpbin/v2/httpbin/digest" + "github.com/mccutchen/go-httpbin/v2/httpbin/websocket" ) var nilValues = url.Values{} @@ -1112,3 +1113,51 @@ func (h *HTTPBin) Hostname(w http.ResponseWriter, _ *http.Request) { Hostname: h.hostname, }) } + +// WebSocketEcho - simple websocket echo server, where the max fragment size +// and max message size can be controlled by clients. +func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) { + var ( + maxFragmentSize = h.MaxBodySize / 2 + maxMessageSize = h.MaxBodySize + q = r.URL.Query() + err error + ) + + if userMaxFragmentSize := q.Get("max_fragment_size"); userMaxFragmentSize != "" { + maxFragmentSize, err = strconv.ParseInt(userMaxFragmentSize, 10, 32) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_fragment_size: %w", err)) + return + } else if maxFragmentSize < 1 || maxFragmentSize > h.MaxBodySize { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_fragment_size: %d not in range [1, %d]", maxFragmentSize, h.MaxBodySize)) + return + } + } + + if userMaxMessageSize := q.Get("max_message_size"); userMaxMessageSize != "" { + maxMessageSize, err = strconv.ParseInt(userMaxMessageSize, 10, 32) + if err != nil { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_message_size: %w", err)) + return + } else if maxMessageSize < 1 || maxMessageSize > h.MaxBodySize { + writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_message_size: %d not in range [1, %d]", maxMessageSize, h.MaxBodySize)) + return + } + } + + if maxFragmentSize > maxMessageSize { + writeError(w, http.StatusBadRequest, fmt.Errorf("max_fragment_size %d must be less than or equal to max_message_size %d", maxFragmentSize, maxMessageSize)) + return + } + + ws := websocket.New(w, r, websocket.Limits{ + MaxFragmentSize: int(maxFragmentSize), + MaxMessageSize: int(maxMessageSize), + }) + if err := ws.Handshake(); err != nil { + writeError(w, http.StatusBadRequest, err) + return + } + ws.Serve(websocket.EchoHandler) +} diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go index fc5387d8..22f85afb 100644 --- a/httpbin/handlers_test.go +++ b/httpbin/handlers_test.go @@ -2911,6 +2911,78 @@ func TestHostname(t *testing.T) { }) } +func TestWebSocketEcho(t *testing.T) { + // ======================================================================== + // Note: Here we only test input validation for the websocket endpoint. + // + // See websocket/*_test.go for in-depth integration tests of the actual + // websocket implementation. + // ======================================================================== + + handshakeHeaders := map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + } + + t.Run("handshake ok", func(t *testing.T) { + t.Parallel() + + req := newTestRequest(t, http.MethodGet, "/websocket/echo") + for k, v := range handshakeHeaders { + req.Header.Set(k, v) + } + + resp, err := client.Do(req) + assert.NilError(t, err) + assert.StatusCode(t, resp, http.StatusSwitchingProtocols) + }) + + t.Run("handshake failed", func(t *testing.T) { + t.Parallel() + req := newTestRequest(t, http.MethodGet, "/websocket/echo") + resp, err := client.Do(req) + assert.NilError(t, err) + assert.StatusCode(t, resp, http.StatusBadRequest) + }) + + paramTests := []struct { + query string + wantStatus int + }{ + // ok + {"max_fragment_size=1&max_message_size=2", http.StatusSwitchingProtocols}, + {fmt.Sprintf("max_fragment_size=%d&max_message_size=%d", app.MaxBodySize, app.MaxBodySize), http.StatusSwitchingProtocols}, + + // bad max_framgent_size + {"max_fragment_size=-1&max_message_size=2", http.StatusBadRequest}, + {"max_fragment_size=0&max_message_size=2", http.StatusBadRequest}, + {"max_fragment_size=3&max_message_size=2", http.StatusBadRequest}, + {"max_fragment_size=foo&max_message_size=2", http.StatusBadRequest}, + {fmt.Sprintf("max_fragment_size=%d&max_message_size=2", app.MaxBodySize+1), http.StatusBadRequest}, + + // bad max_message_size + {"max_fragment_size=1&max_message_size=0", http.StatusBadRequest}, + {"max_fragment_size=1&max_message_size=-1", http.StatusBadRequest}, + {"max_fragment_size=1&max_message_size=bar", http.StatusBadRequest}, + {fmt.Sprintf("max_fragment_size=1&max_message_size=%d", app.MaxBodySize+1), http.StatusBadRequest}, + } + for _, tc := range paramTests { + tc := tc + t.Run(tc.query, func(t *testing.T) { + t.Parallel() + req := newTestRequest(t, http.MethodGet, "/websocket/echo?"+tc.query) + for k, v := range handshakeHeaders { + req.Header.Set(k, v) + } + resp, err := client.Do(req) + assert.NilError(t, err) + assert.StatusCode(t, resp, tc.wantStatus) + }) + } +} + func newTestServer(handler http.Handler) (*httptest.Server, *http.Client) { srv := httptest.NewServer(handler) client := srv.Client() diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go index 81315fb6..b1b420a3 100644 --- a/httpbin/httpbin.go +++ b/httpbin/httpbin.go @@ -153,6 +153,8 @@ func (h *HTTPBin) Handler() http.Handler { mux.HandleFunc("/dump/request", h.DumpRequest) + mux.HandleFunc("/websocket/echo", h.WebSocketEcho) + // existing httpbin endpoints that we do not support mux.HandleFunc("/brotli", notImplementedHandler) diff --git a/httpbin/middleware.go b/httpbin/middleware.go index 15d507ae..02d070a1 100644 --- a/httpbin/middleware.go +++ b/httpbin/middleware.go @@ -1,8 +1,10 @@ package httpbin import ( + "bufio" "fmt" "log" + "net" "net/http" "time" ) @@ -123,6 +125,10 @@ func (mw *metaResponseWriter) Size() int64 { return mw.size } +func (mw *metaResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return mw.w.(http.Hijacker).Hijack() +} + func observe(o Observer, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mw := &metaResponseWriter{w: w} diff --git a/httpbin/static/index.html b/httpbin/static/index.html index 1599af8b..4d707eea 100644 --- a/httpbin/static/index.html +++ b/httpbin/static/index.html @@ -115,6 +115,7 @@

ENDPOINTS

  • /unstable Fails half the time, accepts optional failure_rate float and seed integer parameters.
  • /user-agent Returns user-agent.
  • /uuid Generates a UUIDv4 value.
  • +
  • /websocket/echo?max_fragment_size=2048&max_message_size=10240 A WebSocket echo service.
  • /xml Returns some XML
  • diff --git a/httpbin/websocket/websocket.go b/httpbin/websocket/websocket.go new file mode 100644 index 00000000..e85f79a4 --- /dev/null +++ b/httpbin/websocket/websocket.go @@ -0,0 +1,474 @@ +// Package websocket implements a basic websocket server. +package websocket + +import ( + "bufio" + "context" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "errors" + "fmt" + "io" + "net/http" + "strings" + "unicode/utf8" +) + +const requiredVersion = "13" + +// Opcode is a websocket OPCODE. +type Opcode uint8 + +// See the RFC for the set of defined opcodes: +// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 +const ( + OpcodeContinuation Opcode = 0x0 + OpcodeText Opcode = 0x1 + OpcodeBinary Opcode = 0x2 + OpcodeClose Opcode = 0x8 + OpcodePing Opcode = 0x9 + OpcodePong Opcode = 0xA +) + +// StatusCode is a websocket status code. +type StatusCode uint16 + +// See the RFC for the set of defined status codes: +// https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupported StatusCode = 1003 + StatusNoStatusRcvd StatusCode = 1005 + StatusAbnormalClose StatusCode = 1006 + StatusUnsupportedPayload StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusTooLarge StatusCode = 1009 + StatusTlSHandshake StatusCode = 1015 + StatusServerError StatusCode = 1011 +) + +// Frame is a websocket protocol frame. +type Frame struct { + Fin bool + RSV1 bool + RSV3 bool + RSV2 bool + Opcode Opcode + Payload []byte +} + +// Message is an application-level message from the client, which may be +// constructed from one or more individual protocol frames. +type Message struct { + Binary bool + Payload []byte +} + +// Handler handles a single websocket message. If the returned message is +// non-nil, it will be sent to the client. If an error is returned, the +// connection will be closed. +type Handler func(ctx context.Context, msg *Message) (*Message, error) + +// EchoHandler is a Handler that echoes each incoming message back to the +// client. +var EchoHandler Handler = func(ctx context.Context, msg *Message) (*Message, error) { + return msg, nil +} + +// Limits define the limits imposed on a websocket connection. +type Limits struct { + MaxFragmentSize int + MaxMessageSize int +} + +// WebSocket is a websocket connection. +type WebSocket struct { + w http.ResponseWriter + r *http.Request + maxFragmentSize int + maxMessageSize int + handshook bool +} + +// New creates a new websocket. +func New(w http.ResponseWriter, r *http.Request, limits Limits) *WebSocket { + return &WebSocket{ + w: w, + r: r, + maxFragmentSize: limits.MaxFragmentSize, + maxMessageSize: limits.MaxMessageSize, + } +} + +// Handshake validates the request and performs the WebSocket handshake. If +// Handshake returns nil, only websocket frames should be written to the +// response writer. +func (s *WebSocket) Handshake() error { + if s.handshook { + panic("websocket: handshake already completed") + } + + if strings.ToLower(s.r.Header.Get("Connection")) != "upgrade" { + return fmt.Errorf("missing required `Connection: upgrade` header") + } + if strings.ToLower(s.r.Header.Get("Upgrade")) != "websocket" { + return fmt.Errorf("missing required `Upgrade: websocket` header") + } + if v := s.r.Header.Get("Sec-Websocket-Version"); v != requiredVersion { + return fmt.Errorf("only websocket version %q is supported, got %q", requiredVersion, v) + } + + clientKey := s.r.Header.Get("Sec-Websocket-Key") + if clientKey == "" { + return fmt.Errorf("missing required `Sec-Websocket-Key` header") + } + + s.w.Header().Set("Connection", "upgrade") + s.w.Header().Set("Upgrade", "websocket") + s.w.Header().Set("Sec-Websocket-Accept", acceptKey(clientKey)) + s.w.WriteHeader(http.StatusSwitchingProtocols) + + s.handshook = true + return nil +} + +// Serve handles a websocket connection after the handshake has been completed. +func (s *WebSocket) Serve(handler Handler) { + if !s.handshook { + panic("websocket: serve: handshake not completed") + } + + hj, ok := s.w.(http.Hijacker) + if !ok { + panic("websocket: serve: server does not support hijacking") + } + + conn, buf, err := hj.Hijack() + if err != nil { + panic(fmt.Errorf("websocket: serve: hijack failed: %s", err)) + } + defer conn.Close() + + // errors intentionally ignored here. it's serverLoop's responsibility to + // properly close the websocket connection with a useful error message, and + // any unexpected error returned from serverLoop is not actionable. + _ = s.serveLoop(s.r.Context(), buf, handler) +} + +func (s *WebSocket) serveLoop(ctx context.Context, buf *bufio.ReadWriter, handler Handler) error { + var currentMsg *Message + + for { + select { + case <-ctx.Done(): + return nil + default: + } + + frame, err := nextFrame(buf) + if err != nil { + return writeCloseFrame(buf, StatusServerError, err) + } + + if err := validateFrame(frame, s.maxFragmentSize); err != nil { + return writeCloseFrame(buf, StatusProtocolError, err) + } + + switch frame.Opcode { + case OpcodeBinary, OpcodeText: + if currentMsg != nil { + return writeCloseFrame(buf, StatusProtocolError, errors.New("expected continuation frame")) + } + if frame.Opcode == OpcodeText && !utf8.Valid(frame.Payload) { + return writeCloseFrame(buf, StatusUnsupportedPayload, errors.New("invalid UTF-8")) + } + currentMsg = &Message{ + Binary: frame.Opcode == OpcodeBinary, + Payload: frame.Payload, + } + case OpcodeContinuation: + if currentMsg == nil { + return writeCloseFrame(buf, StatusProtocolError, errors.New("unexpected continuation frame")) + } + if !currentMsg.Binary && !utf8.Valid(frame.Payload) { + return writeCloseFrame(buf, StatusUnsupportedPayload, errors.New("invalid UTF-8")) + } + currentMsg.Payload = append(currentMsg.Payload, frame.Payload...) + if len(currentMsg.Payload) > s.maxMessageSize { + return writeCloseFrame(buf, StatusTooLarge, fmt.Errorf("message size %d exceeds maximum of %d bytes", len(currentMsg.Payload), s.maxMessageSize)) + } + case OpcodeClose: + return writeCloseFrame(buf, StatusNormalClosure, nil) + case OpcodePing: + frame.Opcode = OpcodePong + if err := writeFrame(buf, frame); err != nil { + return err + } + continue + case OpcodePong: + continue + default: + return writeCloseFrame(buf, StatusProtocolError, fmt.Errorf("unsupported opcode: %v", frame.Opcode)) + } + + if frame.Fin { + resp, err := handler(ctx, currentMsg) + if err != nil { + return writeCloseFrame(buf, StatusServerError, err) + } + if resp == nil { + continue + } + for _, respFrame := range frameResponse(resp, s.maxFragmentSize) { + if err := writeFrame(buf, respFrame); err != nil { + return err + } + } + currentMsg = nil + } + } +} + +func nextFrame(buf *bufio.ReadWriter) (*Frame, error) { + bb := make([]byte, 2) + if _, err := io.ReadFull(buf, bb); err != nil { + return nil, err + } + + b0 := bb[0] + b1 := bb[1] + + var ( + fin = b0&0b10000000 != 0 + rsv1 = b0&0b01000000 != 0 + rsv2 = b0&0b00100000 != 0 + rsv3 = b0&0b00010000 != 0 + opcode = Opcode(b0 & 0b00001111) + ) + + // Per https://datatracker.ietf.org/doc/html/rfc6455#section-5.2, all + // client frames must be masked. + if masked := b1 & 0b10000000; masked == 0 { + return nil, fmt.Errorf("received unmasked client frame") + } + + var payloadLength uint64 + switch { + case b1-128 <= 125: + // Payload length is directly represented in the second byte + payloadLength = uint64(b1 - 128) + case b1-128 == 126: + // Payload length is represented in the next 2 bytes (16-bit unsigned integer) + var l uint16 + if err := binary.Read(buf, binary.BigEndian, &l); err != nil { + return nil, err + } + payloadLength = uint64(l) + case b1-128 == 127: + // Payload length is represented in the next 8 bytes (64-bit unsigned integer) + if err := binary.Read(buf, binary.BigEndian, &payloadLength); err != nil { + return nil, err + } + } + + mask := make([]byte, 4) + if _, err := io.ReadFull(buf, mask); err != nil { + return nil, err + } + + payload := make([]byte, payloadLength) + if _, err := io.ReadFull(buf, payload); err != nil { + return nil, err + } + + for i, b := range payload { + payload[i] = b ^ mask[i%4] + } + + return &Frame{ + Fin: fin, + RSV1: rsv1, + RSV2: rsv2, + RSV3: rsv3, + Opcode: opcode, + Payload: payload, + }, nil +} + +func writeFrame(dst *bufio.ReadWriter, frame *Frame) error { + // FIN, RSV1-3, OPCODE + var b1 byte + if frame.Fin { + b1 |= 0b10000000 + } + if frame.RSV1 { + b1 |= 0b01000000 + } + if frame.RSV2 { + b1 |= 0b00100000 + } + if frame.RSV3 { + b1 |= 0b00010000 + } + b1 |= uint8(frame.Opcode) & 0b00001111 + if err := dst.WriteByte(b1); err != nil { + return err + } + + // payload length + payloadLen := int64(len(frame.Payload)) + switch { + case payloadLen <= 125: + if err := dst.WriteByte(byte(payloadLen)); err != nil { + return err + } + case payloadLen <= 65535: + if err := dst.WriteByte(126); err != nil { + return err + } + if err := binary.Write(dst, binary.BigEndian, uint16(payloadLen)); err != nil { + return err + } + default: + if err := dst.WriteByte(127); err != nil { + return err + } + if err := binary.Write(dst, binary.BigEndian, payloadLen); err != nil { + return err + } + } + + // payload + if _, err := dst.Write(frame.Payload); err != nil { + return err + } + + return dst.Flush() +} + +// writeCloseFrame writes a close frame to the wire, with an optional error +// message. +func writeCloseFrame(dst *bufio.ReadWriter, code StatusCode, err error) error { + var payload []byte + payload = binary.BigEndian.AppendUint16(payload, uint16(code)) + if err != nil { + payload = append(payload, []byte(err.Error())...) + } + return writeFrame(dst, &Frame{ + Fin: true, + Opcode: OpcodeClose, + Payload: payload, + }) +} + +// frameResponse splits a message into N frames with payloads of at most +// fragmentSize bytes. +func frameResponse(msg *Message, fragmentSize int) []*Frame { + var result []*Frame + + fin := false + opcode := OpcodeText + if msg.Binary { + opcode = OpcodeBinary + } + + offset := 0 + dataLen := len(msg.Payload) + for { + if offset > 0 { + opcode = OpcodeContinuation + } + end := offset + fragmentSize + if end >= dataLen { + fin = true + end = dataLen + } + result = append(result, &Frame{ + Fin: fin, + Opcode: opcode, + Payload: msg.Payload[offset:end], + }) + if fin { + break + } + } + return result +} + +var reservedStatusCodes = map[uint16]bool{ + // Explicitly reserved by RFC section 7.4.1 Defined Status Codes: + // https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1 + 1004: true, + 1005: true, + 1006: true, + 1015: true, + // Apparently reserved, according to the autobahn testsuite's fuzzingclient + // tests, though it's not clear to me why, based on the RFC. + // + // See: https://github.com/crossbario/autobahn-testsuite + 1016: true, + 1100: true, + 2000: true, + 2999: true, +} + +func validateFrame(frame *Frame, maxFragmentSize int) error { + // We do not support any extensions, per the spec all RSV bits must be 0: + // https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + if frame.RSV1 || frame.RSV2 || frame.RSV3 { + return fmt.Errorf("frame has unsupported RSV bits set") + } + + switch frame.Opcode { + case OpcodeContinuation, OpcodeText, OpcodeBinary: + if len(frame.Payload) > maxFragmentSize { + return fmt.Errorf("frame payload size %d exceeds maximum of %d bytes", len(frame.Payload), maxFragmentSize) + } + case OpcodeClose, OpcodePing, OpcodePong: + // All control frames MUST have a payload length of 125 bytes or less + // and MUST NOT be fragmented. + // https://datatracker.ietf.org/doc/html/rfc6455#section-5.5 + if len(frame.Payload) > 125 { + return fmt.Errorf("frame payload size %d exceeds 125 bytes", len(frame.Payload)) + } + if !frame.Fin { + return fmt.Errorf("control frame %v must not be fragmented", frame.Opcode) + } + } + + if frame.Opcode == OpcodeClose { + if len(frame.Payload) == 0 { + return nil + } + if len(frame.Payload) == 1 { + return fmt.Errorf("close frame payload must be at least 2 bytes") + } + + code := binary.BigEndian.Uint16(frame.Payload[:2]) + if code < 1000 || code >= 5000 { + return fmt.Errorf("close frame status code %d out of range", code) + } + if reservedStatusCodes[code] { + return fmt.Errorf("close frame status code %d is reserved", code) + } + + if len(frame.Payload) > 2 { + if !utf8.Valid(frame.Payload[2:]) { + return errors.New("close frame payload must be vaid UTF-8") + } + } + } + + return nil +} + +func acceptKey(clientKey string) string { + // Magic value comes from RFC 6455 section 1.3: Opening Handshake + // https://www.rfc-editor.org/rfc/rfc6455#section-1.3 + h := sha1.New() + io.WriteString(h, clientKey+"258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/httpbin/websocket/websocket_autobahn_test.go b/httpbin/websocket/websocket_autobahn_test.go new file mode 100644 index 00000000..c1ae0c95 --- /dev/null +++ b/httpbin/websocket/websocket_autobahn_test.go @@ -0,0 +1,211 @@ +package websocket_test + +import ( + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "os" + "os/exec" + "path" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/mccutchen/go-httpbin/v2/httpbin/websocket" + "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" +) + +const autobahnImage = "crossbario/autobahn-testsuite:0.8.2" + +var defaultIncludedTestCases = []string{ + "*", +} + +var defaultExcludedTestCases = []string{ + // These cases all seem to rely on the server accepting fragmented text + // frames with invalid utf8 payloads, but the spec seems to indicate that + // every text fragment must be valid utf8 on its own. + "6.2.3", + "6.2.4", + "6.4.2", + + // Compression extensions are not supported + "12.*", + "13.*", +} + +func TestWebSocketServer(t *testing.T) { + t.Parallel() + + if os.Getenv("AUTOBAHN_TESTS") == "" { + t.Skipf("set AUTOBAHN_TESTS=1 to run autobahn integration tests") + } + + includedTestCases := defaultIncludedTestCases + excludedTestCases := defaultExcludedTestCases + if userTestCases := os.Getenv("AUTOBAHN_CASES"); userTestCases != "" { + t.Logf("using AUTOBAHN_CASES=%q", userTestCases) + includedTestCases = strings.Split(userTestCases, ",") + excludedTestCases = []string{} + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws := websocket.New(w, r, websocket.Limits{ + MaxFragmentSize: 1024 * 1024 * 16, + MaxMessageSize: 1024 * 1024 * 16, + }) + if err := ws.Handshake(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ws.Serve(websocket.EchoHandler) + })) + defer srv.Close() + + testDir := newTestDir(t) + t.Logf("test dir: %s", testDir) + + targetURL := newAutobahnTargetURL(t, srv) + t.Logf("target url: %s", targetURL) + + autobahnCfg := map[string]any{ + "servers": []map[string]string{ + { + "agent": "go-httpbin", + "url": targetURL, + }, + }, + "outdir": "/testdir/report", + "cases": includedTestCases, + "exclude-cases": excludedTestCases, + } + + autobahnCfgFile, err := os.Create(path.Join(testDir, "autobahn.json")) + assert.NilError(t, err) + assert.NilError(t, json.NewEncoder(autobahnCfgFile).Encode(autobahnCfg)) + autobahnCfgFile.Close() + + pullCmd := exec.Command("docker", "pull", autobahnImage) + runCmd(t, pullCmd) + + testCmd := exec.Command( + "docker", + "run", + "--net=host", + "--rm", + "-v", testDir+":/testdir:rw", + autobahnImage, + "wstest", "-m", "fuzzingclient", "--spec", "/testdir/autobahn.json", + ) + runCmd(t, testCmd) + + summary := loadSummary(t, testDir) + if len(summary) == 0 { + t.Fatalf("empty autobahn test summary; check autobahn logs for problems connecting to test server at %q", targetURL) + } + + for _, results := range summary { + for caseName, result := range results { + result := result + t.Run("autobahn/"+caseName, func(t *testing.T) { + if result.Behavior == "FAILED" || result.BehaviorClose == "FAILED" { + report := loadReport(t, testDir, result.ReportFile) + t.Errorf("description: %s", report.Description) + t.Errorf("expectation: %s", report.Expectation) + t.Errorf("result: %s", report.Result) + t.Errorf("close: %s", report.ResultClose) + } + }) + } + } + + t.Logf("autobahn test report: %s", path.Join(testDir, "report/index.html")) + if os.Getenv("AUTOBAHN_OPEN_REPORT") != "" { + runCmd(t, exec.Command("open", path.Join(testDir, "report/index.html"))) + } +} + +// newAutobahnTargetURL returns the URL that the autobahn test suite should use +// to connect to the given httptest server. +// +// On Macs, the docker engine is running inside an implicit VM, so even with +// --net=host, we need to use the special hostname to escape the VM. +// +// See the Docker Desktop docs[1] for more information. This same special +// hostname seems to work across Docker Desktop for Mac, OrbStack, and Colima. +// +// [1]: https://docs.docker.com/desktop/networking/#i-want-to-connect-from-a-container-to-a-service-on-the-host +func newAutobahnTargetURL(t *testing.T, srv *httptest.Server) string { + t.Helper() + u, err := url.Parse(srv.URL) + assert.NilError(t, err) + + var host string + switch runtime.GOOS { + case "darwin": + host = "host.docker.internal" + default: + host = "127.0.0.1" + } + + return fmt.Sprintf("ws://%s:%s/websocket/echo", host, u.Port()) +} + +func runCmd(t *testing.T, cmd *exec.Cmd) { + t.Helper() + t.Logf("running command: %s", cmd.String()) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + assert.NilError(t, cmd.Run()) +} + +func newTestDir(t *testing.T) string { + t.Helper() + + // package tests are run with the package as the working directory, but we + // want to store our integration test output in the repo root + testDir, err := filepath.Abs(path.Join( + "..", "..", ".integrationtests", fmt.Sprintf("autobahn-test-%d", time.Now().Unix()), + )) + + assert.NilError(t, err) + assert.NilError(t, os.MkdirAll(testDir, 0o755)) + return testDir +} + +func loadSummary(t *testing.T, testDir string) autobahnReportSummary { + t.Helper() + f, err := os.Open(path.Join(testDir, "report", "index.json")) + assert.NilError(t, err) + defer f.Close() + var summary autobahnReportSummary + assert.NilError(t, json.NewDecoder(f).Decode(&summary)) + return summary +} + +func loadReport(t *testing.T, testDir string, reportFile string) autobahnReportResult { + t.Helper() + reportPath := path.Join(testDir, "report", reportFile) + t.Logf("report path: %s", reportPath) + f, err := os.Open(reportPath) + assert.NilError(t, err) + var report autobahnReportResult + assert.NilError(t, json.NewDecoder(f).Decode(&report)) + return report +} + +type autobahnReportSummary map[string]map[string]autobahnReportResult + +type autobahnReportResult struct { + Behavior string `json:"behavior"` + BehaviorClose string `json:"behaviorClose"` + Description string `json:"description"` + Expectation string `json:"expectation"` + ReportFile string `json:"reportfile"` + Result string `json:"result"` + ResultClose string `json:"resultClose"` +} diff --git a/httpbin/websocket/websocket_test.go b/httpbin/websocket/websocket_test.go new file mode 100644 index 00000000..cb1084e0 --- /dev/null +++ b/httpbin/websocket/websocket_test.go @@ -0,0 +1,246 @@ +package websocket_test + +import ( + "bufio" + "fmt" + "net" + "net/http" + "net/http/httptest" + "testing" + + "github.com/mccutchen/go-httpbin/v2/httpbin/websocket" + "github.com/mccutchen/go-httpbin/v2/internal/testing/assert" +) + +func TestHandshake(t *testing.T) { + testCases := map[string]struct { + reqHeaders map[string]string + wantStatus int + wantRespHeaders map[string]string + }{ + "valid handshake": { + reqHeaders: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + wantRespHeaders: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-Websocket-Accept": "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + }, + wantStatus: http.StatusSwitchingProtocols, + }, + "valid handshake, header values case insensitive": { + reqHeaders: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "WebSocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + wantRespHeaders: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-Websocket-Accept": "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", + }, + wantStatus: http.StatusSwitchingProtocols, + }, + "missing Connection header": { + reqHeaders: map[string]string{ + "Upgrade": "websocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + wantStatus: http.StatusBadRequest, + }, + "incorrect Connection header": { + reqHeaders: map[string]string{ + "Connection": "close", + "Upgrade": "websocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + wantStatus: http.StatusBadRequest, + }, + "missing Upgrade header": { + reqHeaders: map[string]string{ + "Connection": "Upgrade", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + wantStatus: http.StatusBadRequest, + }, + "incorrect Upgrade header": { + reqHeaders: map[string]string{ + "Connection": "Upgrade", + "Upgrade": "http/2", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + }, + wantStatus: http.StatusBadRequest, + }, + "missing version": { + reqHeaders: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + }, + wantStatus: http.StatusBadRequest, + }, + "incorrect version": { + reqHeaders: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "12", + }, + wantStatus: http.StatusBadRequest, + }, + "missing Sec-WebSocket-Key": { + reqHeaders: map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Version": "13", + }, + wantStatus: http.StatusBadRequest, + }, + } + for name, tc := range testCases { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws := websocket.New(w, r, websocket.Limits{}) + if err := ws.Handshake(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + ws.Serve(websocket.EchoHandler) + })) + defer srv.Close() + + req, _ := http.NewRequest(http.MethodGet, srv.URL, nil) + for k, v := range tc.reqHeaders { + req.Header.Set(k, v) + } + + resp, err := http.DefaultClient.Do(req) + assert.NilError(t, err) + + assert.StatusCode(t, resp, tc.wantStatus) + for k, v := range tc.wantRespHeaders { + assert.Equal(t, resp.Header.Get(k), v, "incorrect value for %q response header", k) + } + }) + } +} + +func TestHandshakeOrder(t *testing.T) { + handshakeReq := httptest.NewRequest(http.MethodGet, "/websocket/echo", nil) + for k, v := range map[string]string{ + "Connection": "upgrade", + "Upgrade": "websocket", + "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==", + "Sec-WebSocket-Version": "13", + } { + handshakeReq.Header.Set(k, v) + } + + t.Run("double handshake", func(t *testing.T) { + w := httptest.NewRecorder() + ws := websocket.New(w, handshakeReq, websocket.Limits{}) + + // first handshake succeeds + assert.NilError(t, ws.Handshake()) + assert.Equal(t, w.Code, http.StatusSwitchingProtocols, "incorrect status code") + + // second handshake fails + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected to catch panic on double handshake") + } + assert.Equal(t, fmt.Sprint(r), "websocket: handshake already completed", "incorrect panic message") + }() + ws.Handshake() + }) + + t.Run("handshake not completed", func(t *testing.T) { + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected to catch panic on Serve before Handshake") + } + assert.Equal(t, fmt.Sprint(r), "websocket: serve: handshake not completed", "incorrect panic message") + }() + w := httptest.NewRecorder() + websocket.New(w, handshakeReq, websocket.Limits{}).Serve(nil) + }) + + t.Run("http.Hijack not implemented", func(t *testing.T) { + // confirm that httptest.ResponseRecorder does not implmeent + // http.Hjijacker + var rw http.ResponseWriter = httptest.NewRecorder() + _, ok := rw.(http.Hijacker) + assert.Equal(t, ok, false, "expected httptest.ResponseRecorder not to implement http.Hijacker") + + w := httptest.NewRecorder() + ws := websocket.New(w, handshakeReq, websocket.Limits{}) + + assert.NilError(t, ws.Handshake()) + assert.Equal(t, w.Code, http.StatusSwitchingProtocols, "incorrect status code") + + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected to catch panic on when http.Hijack not implemented") + } + assert.Equal(t, fmt.Sprint(r), "websocket: serve: server does not support hijacking", "incorrect panic message") + }() + ws.Serve(nil) + }) + + t.Run("hijack failed", func(t *testing.T) { + w := &brokenHijackResponseWriter{} + ws := websocket.New(w, handshakeReq, websocket.Limits{}) + + assert.NilError(t, ws.Handshake()) + assert.Equal(t, w.Code, http.StatusSwitchingProtocols, "incorrect status code") + + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected to catch panic on Serve before Handshake") + } + assert.Equal(t, fmt.Sprint(r), "websocket: serve: hijack failed: error hijacking connection", "incorrect panic message") + }() + ws.Serve(nil) + }) +} + +// brokenHijackResponseWriter implements just enough to satisfy the +// http.ResponseWriter and http.Hijacker interfaces and get through the +// handshake before failing to actually hijack the connection. +type brokenHijackResponseWriter struct { + http.ResponseWriter + Code int +} + +func (w *brokenHijackResponseWriter) WriteHeader(code int) { + w.Code = code +} + +func (w *brokenHijackResponseWriter) Header() http.Header { + return http.Header{} +} + +func (brokenHijackResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return nil, nil, fmt.Errorf("error hijacking connection") +} + +var ( + _ http.ResponseWriter = &brokenHijackResponseWriter{} + _ http.Hijacker = &brokenHijackResponseWriter{} +)