diff --git a/.gitignore b/.gitignore
index a3b9b4e0..348636d5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,27 +1,6 @@
-# Compiled Object files, Static and Dynamic libs (Shared Objects)
-*.o
-*.a
-*.so
-
-# Folders
-_obj
-_test
-
-# Architecture specific extensions/prefixes
-*.[568vq]
-[568vq].out
-
-*.cgo1.go
-*.cgo2.c
-_cgo_defun.c
-_cgo_gotypes.go
-_cgo_export.*
-
-_testmain.go
-
+.*
*.exe
*.test
*.prof
-
dist/*
coverage.txt
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
new file mode 100644
index 00000000..27265e31
--- /dev/null
+++ b/DEVELOPMENT.md
@@ -0,0 +1,113 @@
+# Development
+
+## Local development
+
+For interactive local development, use `make run` to build and run go-httpbin
+or `make watch` to automatically re-build and re-run go-httpbin on every
+change:
+
+ make run
+ make watch
+
+By default, the server will listen on `http://127.0.0.1:8080`, but the host,
+port, or any other [configuration option][config] may be overridden by
+specifying the relevant environment variables:
+
+ make run PORT=9999
+ make run PORT=9999 MAX_DURATION=60s
+ make watch HOST=0.0.0.0 PORT=8888
+
+## Testing
+
+Run `make test` to run unit tests, using `TEST_ARGS` to pass arguments through
+to `go test`:
+
+ make test
+ make test TEST_ARGS="-v -race -run ^TestDelay"
+
+### Integration tests
+
+go-httpbin includes its own minimal WebSocket echo server implementation, and
+we use the incredibly helpful [Autobahn Testsuite][] to ensure that the
+implementation conforms to the spec.
+
+These tests can be slow to run (~40 seconds on my machine), so they are not run
+by default when using `make test`.
+
+They are run automatically as part of our extended "CI" test suite, which is
+run on every pull request:
+
+ make testci
+
+### WebSocket development
+
+When working on the WebSocket implementation, it can also be useful to run
+those integration tests directly, like so:
+
+ make testautobahn
+
+Use the `AUTOBAHN_CASES` var to run a specific subset of the Autobahn tests,
+which may or may not include wildcards:
+
+ make testautobahn AUTOBAHN_CASES=6.*
+ make testautobahn AUTOBAHN_CASES=6.5.*
+ make testautobahn AUTOBAHN_CASES=6.5.4
+
+
+### Test coverage
+
+We use [Codecov][] to measure and track test coverage as part of our continuous
+integration test suite. While we strive for as much coverage as possible and
+the Codecov CI check is configured with fairly strict requirements, 100% test
+coverage is not an explicit goal or requirement for all contributions.
+
+To view test coverage locally, use
+
+ make testcover
+
+which will run the full suite of unit and integration tests and pop open a web
+browser to view coverage results.
+
+
+## Linting and code style
+
+Run `make lint` to run our suite of linters and formatters, which include
+gofmt, [revive][], and [staticcheck][]:
+
+ make lint
+
+
+## Docker images
+
+To build a docker image locally:
+
+ make image
+
+To build a docker image an push it to a remote repository:
+
+ make imagepush
+
+By default, images will be tagged as `mccutchen/go-httpbin:${COMMIT}` with the
+current HEAD commit hash.
+
+Use `VERSION` to override the tag value
+
+ make imagepush VERSION=v1.2.3
+
+or `DOCKER_TAG` to override the remote repo and version at once:
+
+ make imagepush DOCKER_TAG=my-org/my-fork:v1.2.3
+
+### Automated docker image builds
+
+When a new release is created, the [Release][] GitHub Actions workflow
+automatically builds and pushes new Docker images for both linux/amd64 and
+linux/arm64 architectures.
+
+
+[config]: /README.md#configuration
+[revive]: https://github.com/mgechev/revive
+[staticcheck]: https://staticcheck.dev/
+[Release]: /.github/workflows/release.yaml
+[Codecov]: https://app.codecov.io/gh/mccutchen/go-httpbin
+[Autobahn Testsuite]: https://github.com/crossbario/autobahn-testsuite
diff --git a/Makefile b/Makefile
index 033d7655..ad3f9c7e 100644
--- a/Makefile
+++ b/Makefile
@@ -38,7 +38,7 @@ buildtests:
.PHONY: buildtests
clean:
- rm -rf $(DIST_PATH) $(COVERAGE_PATH)
+ rm -rf $(DIST_PATH) $(COVERAGE_PATH) .integrationtests
.PHONY: clean
@@ -53,14 +53,18 @@ test:
# based on codecov.io's documentation:
# https://github.com/codecov/example-go/blob/b85638743b972bd0bd2af63421fe513c6f968930/README.md
testci: build buildexamples
- go test $(TEST_ARGS) $(COVERAGE_ARGS) ./...
- git diff --exit-code
+ AUTOBAHN_TESTS=1 go test $(TEST_ARGS) $(COVERAGE_ARGS) ./...
.PHONY: testci
testcover: testci
go tool cover -html=$(COVERAGE_PATH)
.PHONY: testcover
+# Run the autobahn fuzzingclient test suite
+testautobahn:
+ AUTOBAHN_TESTS=1 AUTOBAHN_OPEN_REPORT=1 go test -v -run ^TestWebSocketServer$$ $(TEST_ARGS) ./...
+.PHONY: autobahntests
+
lint:
test -z "$$(gofmt -d -s -e .)" || (echo "Error: gofmt failed"; gofmt -d -s -e . ; exit 1)
go vet ./...
diff --git a/README.md b/README.md
index 9d3d9a57..58f0c815 100644
--- a/README.md
+++ b/README.md
@@ -169,17 +169,7 @@ public internet, consider tuning it appropriately:
## Development
-```bash
-# local development
-make
-make test
-make testcover
-make run
-
-# building & pushing docker images
-make image
-make imagepush
-```
+See [DEVELOPMENT.md][].
## Motivation & prior art
@@ -218,3 +208,4 @@ Compared to [ahmetb/go-httpbin][ahmet]:
[Observer]: https://pkg.go.dev/github.com/mccutchen/go-httpbin/v2/httpbin#Observer
[Production considerations]: #production-considerations
[zerolog]: https://github.com/rs/zerolog
+[DEVELOPMENT.md]: ./DEVELOPMENT.md
diff --git a/httpbin/handlers.go b/httpbin/handlers.go
index a122a016..7ecb7bc3 100644
--- a/httpbin/handlers.go
+++ b/httpbin/handlers.go
@@ -15,6 +15,7 @@ import (
"time"
"github.com/mccutchen/go-httpbin/v2/httpbin/digest"
+ "github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
)
var nilValues = url.Values{}
@@ -1112,3 +1113,51 @@ func (h *HTTPBin) Hostname(w http.ResponseWriter, _ *http.Request) {
Hostname: h.hostname,
})
}
+
+// WebSocketEcho - simple websocket echo server, where the max fragment size
+// and max message size can be controlled by clients.
+func (h *HTTPBin) WebSocketEcho(w http.ResponseWriter, r *http.Request) {
+ var (
+ maxFragmentSize = h.MaxBodySize / 2
+ maxMessageSize = h.MaxBodySize
+ q = r.URL.Query()
+ err error
+ )
+
+ if userMaxFragmentSize := q.Get("max_fragment_size"); userMaxFragmentSize != "" {
+ maxFragmentSize, err = strconv.ParseInt(userMaxFragmentSize, 10, 32)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_fragment_size: %w", err))
+ return
+ } else if maxFragmentSize < 1 || maxFragmentSize > h.MaxBodySize {
+ writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_fragment_size: %d not in range [1, %d]", maxFragmentSize, h.MaxBodySize))
+ return
+ }
+ }
+
+ if userMaxMessageSize := q.Get("max_message_size"); userMaxMessageSize != "" {
+ maxMessageSize, err = strconv.ParseInt(userMaxMessageSize, 10, 32)
+ if err != nil {
+ writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_message_size: %w", err))
+ return
+ } else if maxMessageSize < 1 || maxMessageSize > h.MaxBodySize {
+ writeError(w, http.StatusBadRequest, fmt.Errorf("invalid max_message_size: %d not in range [1, %d]", maxMessageSize, h.MaxBodySize))
+ return
+ }
+ }
+
+ if maxFragmentSize > maxMessageSize {
+ writeError(w, http.StatusBadRequest, fmt.Errorf("max_fragment_size %d must be less than or equal to max_message_size %d", maxFragmentSize, maxMessageSize))
+ return
+ }
+
+ ws := websocket.New(w, r, websocket.Limits{
+ MaxFragmentSize: int(maxFragmentSize),
+ MaxMessageSize: int(maxMessageSize),
+ })
+ if err := ws.Handshake(); err != nil {
+ writeError(w, http.StatusBadRequest, err)
+ return
+ }
+ ws.Serve(websocket.EchoHandler)
+}
diff --git a/httpbin/handlers_test.go b/httpbin/handlers_test.go
index fc5387d8..22f85afb 100644
--- a/httpbin/handlers_test.go
+++ b/httpbin/handlers_test.go
@@ -2911,6 +2911,78 @@ func TestHostname(t *testing.T) {
})
}
+func TestWebSocketEcho(t *testing.T) {
+ // ========================================================================
+ // Note: Here we only test input validation for the websocket endpoint.
+ //
+ // See websocket/*_test.go for in-depth integration tests of the actual
+ // websocket implementation.
+ // ========================================================================
+
+ handshakeHeaders := map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ }
+
+ t.Run("handshake ok", func(t *testing.T) {
+ t.Parallel()
+
+ req := newTestRequest(t, http.MethodGet, "/websocket/echo")
+ for k, v := range handshakeHeaders {
+ req.Header.Set(k, v)
+ }
+
+ resp, err := client.Do(req)
+ assert.NilError(t, err)
+ assert.StatusCode(t, resp, http.StatusSwitchingProtocols)
+ })
+
+ t.Run("handshake failed", func(t *testing.T) {
+ t.Parallel()
+ req := newTestRequest(t, http.MethodGet, "/websocket/echo")
+ resp, err := client.Do(req)
+ assert.NilError(t, err)
+ assert.StatusCode(t, resp, http.StatusBadRequest)
+ })
+
+ paramTests := []struct {
+ query string
+ wantStatus int
+ }{
+ // ok
+ {"max_fragment_size=1&max_message_size=2", http.StatusSwitchingProtocols},
+ {fmt.Sprintf("max_fragment_size=%d&max_message_size=%d", app.MaxBodySize, app.MaxBodySize), http.StatusSwitchingProtocols},
+
+ // bad max_framgent_size
+ {"max_fragment_size=-1&max_message_size=2", http.StatusBadRequest},
+ {"max_fragment_size=0&max_message_size=2", http.StatusBadRequest},
+ {"max_fragment_size=3&max_message_size=2", http.StatusBadRequest},
+ {"max_fragment_size=foo&max_message_size=2", http.StatusBadRequest},
+ {fmt.Sprintf("max_fragment_size=%d&max_message_size=2", app.MaxBodySize+1), http.StatusBadRequest},
+
+ // bad max_message_size
+ {"max_fragment_size=1&max_message_size=0", http.StatusBadRequest},
+ {"max_fragment_size=1&max_message_size=-1", http.StatusBadRequest},
+ {"max_fragment_size=1&max_message_size=bar", http.StatusBadRequest},
+ {fmt.Sprintf("max_fragment_size=1&max_message_size=%d", app.MaxBodySize+1), http.StatusBadRequest},
+ }
+ for _, tc := range paramTests {
+ tc := tc
+ t.Run(tc.query, func(t *testing.T) {
+ t.Parallel()
+ req := newTestRequest(t, http.MethodGet, "/websocket/echo?"+tc.query)
+ for k, v := range handshakeHeaders {
+ req.Header.Set(k, v)
+ }
+ resp, err := client.Do(req)
+ assert.NilError(t, err)
+ assert.StatusCode(t, resp, tc.wantStatus)
+ })
+ }
+}
+
func newTestServer(handler http.Handler) (*httptest.Server, *http.Client) {
srv := httptest.NewServer(handler)
client := srv.Client()
diff --git a/httpbin/httpbin.go b/httpbin/httpbin.go
index 81315fb6..b1b420a3 100644
--- a/httpbin/httpbin.go
+++ b/httpbin/httpbin.go
@@ -153,6 +153,8 @@ func (h *HTTPBin) Handler() http.Handler {
mux.HandleFunc("/dump/request", h.DumpRequest)
+ mux.HandleFunc("/websocket/echo", h.WebSocketEcho)
+
// existing httpbin endpoints that we do not support
mux.HandleFunc("/brotli", notImplementedHandler)
diff --git a/httpbin/middleware.go b/httpbin/middleware.go
index 15d507ae..02d070a1 100644
--- a/httpbin/middleware.go
+++ b/httpbin/middleware.go
@@ -1,8 +1,10 @@
package httpbin
import (
+ "bufio"
"fmt"
"log"
+ "net"
"net/http"
"time"
)
@@ -123,6 +125,10 @@ func (mw *metaResponseWriter) Size() int64 {
return mw.size
}
+func (mw *metaResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return mw.w.(http.Hijacker).Hijack()
+}
+
func observe(o Observer, h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
mw := &metaResponseWriter{w: w}
diff --git a/httpbin/static/index.html b/httpbin/static/index.html
index 1599af8b..4d707eea 100644
--- a/httpbin/static/index.html
+++ b/httpbin/static/index.html
@@ -115,6 +115,7 @@
ENDPOINTS
/unstable
Fails half the time, accepts optional failure_rate float and seed integer parameters.
/user-agent
Returns user-agent.
/uuid
Generates a UUIDv4 value.
+/websocket/echo?max_fragment_size=2048&max_message_size=10240
A WebSocket echo service.
/xml
Returns some XML
diff --git a/httpbin/websocket/websocket.go b/httpbin/websocket/websocket.go
new file mode 100644
index 00000000..e85f79a4
--- /dev/null
+++ b/httpbin/websocket/websocket.go
@@ -0,0 +1,474 @@
+// Package websocket implements a basic websocket server.
+package websocket
+
+import (
+ "bufio"
+ "context"
+ "crypto/sha1"
+ "encoding/base64"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "strings"
+ "unicode/utf8"
+)
+
+const requiredVersion = "13"
+
+// Opcode is a websocket OPCODE.
+type Opcode uint8
+
+// See the RFC for the set of defined opcodes:
+// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
+const (
+ OpcodeContinuation Opcode = 0x0
+ OpcodeText Opcode = 0x1
+ OpcodeBinary Opcode = 0x2
+ OpcodeClose Opcode = 0x8
+ OpcodePing Opcode = 0x9
+ OpcodePong Opcode = 0xA
+)
+
+// StatusCode is a websocket status code.
+type StatusCode uint16
+
+// See the RFC for the set of defined status codes:
+// https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1
+const (
+ StatusNormalClosure StatusCode = 1000
+ StatusGoingAway StatusCode = 1001
+ StatusProtocolError StatusCode = 1002
+ StatusUnsupported StatusCode = 1003
+ StatusNoStatusRcvd StatusCode = 1005
+ StatusAbnormalClose StatusCode = 1006
+ StatusUnsupportedPayload StatusCode = 1007
+ StatusPolicyViolation StatusCode = 1008
+ StatusTooLarge StatusCode = 1009
+ StatusTlSHandshake StatusCode = 1015
+ StatusServerError StatusCode = 1011
+)
+
+// Frame is a websocket protocol frame.
+type Frame struct {
+ Fin bool
+ RSV1 bool
+ RSV3 bool
+ RSV2 bool
+ Opcode Opcode
+ Payload []byte
+}
+
+// Message is an application-level message from the client, which may be
+// constructed from one or more individual protocol frames.
+type Message struct {
+ Binary bool
+ Payload []byte
+}
+
+// Handler handles a single websocket message. If the returned message is
+// non-nil, it will be sent to the client. If an error is returned, the
+// connection will be closed.
+type Handler func(ctx context.Context, msg *Message) (*Message, error)
+
+// EchoHandler is a Handler that echoes each incoming message back to the
+// client.
+var EchoHandler Handler = func(ctx context.Context, msg *Message) (*Message, error) {
+ return msg, nil
+}
+
+// Limits define the limits imposed on a websocket connection.
+type Limits struct {
+ MaxFragmentSize int
+ MaxMessageSize int
+}
+
+// WebSocket is a websocket connection.
+type WebSocket struct {
+ w http.ResponseWriter
+ r *http.Request
+ maxFragmentSize int
+ maxMessageSize int
+ handshook bool
+}
+
+// New creates a new websocket.
+func New(w http.ResponseWriter, r *http.Request, limits Limits) *WebSocket {
+ return &WebSocket{
+ w: w,
+ r: r,
+ maxFragmentSize: limits.MaxFragmentSize,
+ maxMessageSize: limits.MaxMessageSize,
+ }
+}
+
+// Handshake validates the request and performs the WebSocket handshake. If
+// Handshake returns nil, only websocket frames should be written to the
+// response writer.
+func (s *WebSocket) Handshake() error {
+ if s.handshook {
+ panic("websocket: handshake already completed")
+ }
+
+ if strings.ToLower(s.r.Header.Get("Connection")) != "upgrade" {
+ return fmt.Errorf("missing required `Connection: upgrade` header")
+ }
+ if strings.ToLower(s.r.Header.Get("Upgrade")) != "websocket" {
+ return fmt.Errorf("missing required `Upgrade: websocket` header")
+ }
+ if v := s.r.Header.Get("Sec-Websocket-Version"); v != requiredVersion {
+ return fmt.Errorf("only websocket version %q is supported, got %q", requiredVersion, v)
+ }
+
+ clientKey := s.r.Header.Get("Sec-Websocket-Key")
+ if clientKey == "" {
+ return fmt.Errorf("missing required `Sec-Websocket-Key` header")
+ }
+
+ s.w.Header().Set("Connection", "upgrade")
+ s.w.Header().Set("Upgrade", "websocket")
+ s.w.Header().Set("Sec-Websocket-Accept", acceptKey(clientKey))
+ s.w.WriteHeader(http.StatusSwitchingProtocols)
+
+ s.handshook = true
+ return nil
+}
+
+// Serve handles a websocket connection after the handshake has been completed.
+func (s *WebSocket) Serve(handler Handler) {
+ if !s.handshook {
+ panic("websocket: serve: handshake not completed")
+ }
+
+ hj, ok := s.w.(http.Hijacker)
+ if !ok {
+ panic("websocket: serve: server does not support hijacking")
+ }
+
+ conn, buf, err := hj.Hijack()
+ if err != nil {
+ panic(fmt.Errorf("websocket: serve: hijack failed: %s", err))
+ }
+ defer conn.Close()
+
+ // errors intentionally ignored here. it's serverLoop's responsibility to
+ // properly close the websocket connection with a useful error message, and
+ // any unexpected error returned from serverLoop is not actionable.
+ _ = s.serveLoop(s.r.Context(), buf, handler)
+}
+
+func (s *WebSocket) serveLoop(ctx context.Context, buf *bufio.ReadWriter, handler Handler) error {
+ var currentMsg *Message
+
+ for {
+ select {
+ case <-ctx.Done():
+ return nil
+ default:
+ }
+
+ frame, err := nextFrame(buf)
+ if err != nil {
+ return writeCloseFrame(buf, StatusServerError, err)
+ }
+
+ if err := validateFrame(frame, s.maxFragmentSize); err != nil {
+ return writeCloseFrame(buf, StatusProtocolError, err)
+ }
+
+ switch frame.Opcode {
+ case OpcodeBinary, OpcodeText:
+ if currentMsg != nil {
+ return writeCloseFrame(buf, StatusProtocolError, errors.New("expected continuation frame"))
+ }
+ if frame.Opcode == OpcodeText && !utf8.Valid(frame.Payload) {
+ return writeCloseFrame(buf, StatusUnsupportedPayload, errors.New("invalid UTF-8"))
+ }
+ currentMsg = &Message{
+ Binary: frame.Opcode == OpcodeBinary,
+ Payload: frame.Payload,
+ }
+ case OpcodeContinuation:
+ if currentMsg == nil {
+ return writeCloseFrame(buf, StatusProtocolError, errors.New("unexpected continuation frame"))
+ }
+ if !currentMsg.Binary && !utf8.Valid(frame.Payload) {
+ return writeCloseFrame(buf, StatusUnsupportedPayload, errors.New("invalid UTF-8"))
+ }
+ currentMsg.Payload = append(currentMsg.Payload, frame.Payload...)
+ if len(currentMsg.Payload) > s.maxMessageSize {
+ return writeCloseFrame(buf, StatusTooLarge, fmt.Errorf("message size %d exceeds maximum of %d bytes", len(currentMsg.Payload), s.maxMessageSize))
+ }
+ case OpcodeClose:
+ return writeCloseFrame(buf, StatusNormalClosure, nil)
+ case OpcodePing:
+ frame.Opcode = OpcodePong
+ if err := writeFrame(buf, frame); err != nil {
+ return err
+ }
+ continue
+ case OpcodePong:
+ continue
+ default:
+ return writeCloseFrame(buf, StatusProtocolError, fmt.Errorf("unsupported opcode: %v", frame.Opcode))
+ }
+
+ if frame.Fin {
+ resp, err := handler(ctx, currentMsg)
+ if err != nil {
+ return writeCloseFrame(buf, StatusServerError, err)
+ }
+ if resp == nil {
+ continue
+ }
+ for _, respFrame := range frameResponse(resp, s.maxFragmentSize) {
+ if err := writeFrame(buf, respFrame); err != nil {
+ return err
+ }
+ }
+ currentMsg = nil
+ }
+ }
+}
+
+func nextFrame(buf *bufio.ReadWriter) (*Frame, error) {
+ bb := make([]byte, 2)
+ if _, err := io.ReadFull(buf, bb); err != nil {
+ return nil, err
+ }
+
+ b0 := bb[0]
+ b1 := bb[1]
+
+ var (
+ fin = b0&0b10000000 != 0
+ rsv1 = b0&0b01000000 != 0
+ rsv2 = b0&0b00100000 != 0
+ rsv3 = b0&0b00010000 != 0
+ opcode = Opcode(b0 & 0b00001111)
+ )
+
+ // Per https://datatracker.ietf.org/doc/html/rfc6455#section-5.2, all
+ // client frames must be masked.
+ if masked := b1 & 0b10000000; masked == 0 {
+ return nil, fmt.Errorf("received unmasked client frame")
+ }
+
+ var payloadLength uint64
+ switch {
+ case b1-128 <= 125:
+ // Payload length is directly represented in the second byte
+ payloadLength = uint64(b1 - 128)
+ case b1-128 == 126:
+ // Payload length is represented in the next 2 bytes (16-bit unsigned integer)
+ var l uint16
+ if err := binary.Read(buf, binary.BigEndian, &l); err != nil {
+ return nil, err
+ }
+ payloadLength = uint64(l)
+ case b1-128 == 127:
+ // Payload length is represented in the next 8 bytes (64-bit unsigned integer)
+ if err := binary.Read(buf, binary.BigEndian, &payloadLength); err != nil {
+ return nil, err
+ }
+ }
+
+ mask := make([]byte, 4)
+ if _, err := io.ReadFull(buf, mask); err != nil {
+ return nil, err
+ }
+
+ payload := make([]byte, payloadLength)
+ if _, err := io.ReadFull(buf, payload); err != nil {
+ return nil, err
+ }
+
+ for i, b := range payload {
+ payload[i] = b ^ mask[i%4]
+ }
+
+ return &Frame{
+ Fin: fin,
+ RSV1: rsv1,
+ RSV2: rsv2,
+ RSV3: rsv3,
+ Opcode: opcode,
+ Payload: payload,
+ }, nil
+}
+
+func writeFrame(dst *bufio.ReadWriter, frame *Frame) error {
+ // FIN, RSV1-3, OPCODE
+ var b1 byte
+ if frame.Fin {
+ b1 |= 0b10000000
+ }
+ if frame.RSV1 {
+ b1 |= 0b01000000
+ }
+ if frame.RSV2 {
+ b1 |= 0b00100000
+ }
+ if frame.RSV3 {
+ b1 |= 0b00010000
+ }
+ b1 |= uint8(frame.Opcode) & 0b00001111
+ if err := dst.WriteByte(b1); err != nil {
+ return err
+ }
+
+ // payload length
+ payloadLen := int64(len(frame.Payload))
+ switch {
+ case payloadLen <= 125:
+ if err := dst.WriteByte(byte(payloadLen)); err != nil {
+ return err
+ }
+ case payloadLen <= 65535:
+ if err := dst.WriteByte(126); err != nil {
+ return err
+ }
+ if err := binary.Write(dst, binary.BigEndian, uint16(payloadLen)); err != nil {
+ return err
+ }
+ default:
+ if err := dst.WriteByte(127); err != nil {
+ return err
+ }
+ if err := binary.Write(dst, binary.BigEndian, payloadLen); err != nil {
+ return err
+ }
+ }
+
+ // payload
+ if _, err := dst.Write(frame.Payload); err != nil {
+ return err
+ }
+
+ return dst.Flush()
+}
+
+// writeCloseFrame writes a close frame to the wire, with an optional error
+// message.
+func writeCloseFrame(dst *bufio.ReadWriter, code StatusCode, err error) error {
+ var payload []byte
+ payload = binary.BigEndian.AppendUint16(payload, uint16(code))
+ if err != nil {
+ payload = append(payload, []byte(err.Error())...)
+ }
+ return writeFrame(dst, &Frame{
+ Fin: true,
+ Opcode: OpcodeClose,
+ Payload: payload,
+ })
+}
+
+// frameResponse splits a message into N frames with payloads of at most
+// fragmentSize bytes.
+func frameResponse(msg *Message, fragmentSize int) []*Frame {
+ var result []*Frame
+
+ fin := false
+ opcode := OpcodeText
+ if msg.Binary {
+ opcode = OpcodeBinary
+ }
+
+ offset := 0
+ dataLen := len(msg.Payload)
+ for {
+ if offset > 0 {
+ opcode = OpcodeContinuation
+ }
+ end := offset + fragmentSize
+ if end >= dataLen {
+ fin = true
+ end = dataLen
+ }
+ result = append(result, &Frame{
+ Fin: fin,
+ Opcode: opcode,
+ Payload: msg.Payload[offset:end],
+ })
+ if fin {
+ break
+ }
+ }
+ return result
+}
+
+var reservedStatusCodes = map[uint16]bool{
+ // Explicitly reserved by RFC section 7.4.1 Defined Status Codes:
+ // https://datatracker.ietf.org/doc/html/rfc6455#section-7.4.1
+ 1004: true,
+ 1005: true,
+ 1006: true,
+ 1015: true,
+ // Apparently reserved, according to the autobahn testsuite's fuzzingclient
+ // tests, though it's not clear to me why, based on the RFC.
+ //
+ // See: https://github.com/crossbario/autobahn-testsuite
+ 1016: true,
+ 1100: true,
+ 2000: true,
+ 2999: true,
+}
+
+func validateFrame(frame *Frame, maxFragmentSize int) error {
+ // We do not support any extensions, per the spec all RSV bits must be 0:
+ // https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
+ if frame.RSV1 || frame.RSV2 || frame.RSV3 {
+ return fmt.Errorf("frame has unsupported RSV bits set")
+ }
+
+ switch frame.Opcode {
+ case OpcodeContinuation, OpcodeText, OpcodeBinary:
+ if len(frame.Payload) > maxFragmentSize {
+ return fmt.Errorf("frame payload size %d exceeds maximum of %d bytes", len(frame.Payload), maxFragmentSize)
+ }
+ case OpcodeClose, OpcodePing, OpcodePong:
+ // All control frames MUST have a payload length of 125 bytes or less
+ // and MUST NOT be fragmented.
+ // https://datatracker.ietf.org/doc/html/rfc6455#section-5.5
+ if len(frame.Payload) > 125 {
+ return fmt.Errorf("frame payload size %d exceeds 125 bytes", len(frame.Payload))
+ }
+ if !frame.Fin {
+ return fmt.Errorf("control frame %v must not be fragmented", frame.Opcode)
+ }
+ }
+
+ if frame.Opcode == OpcodeClose {
+ if len(frame.Payload) == 0 {
+ return nil
+ }
+ if len(frame.Payload) == 1 {
+ return fmt.Errorf("close frame payload must be at least 2 bytes")
+ }
+
+ code := binary.BigEndian.Uint16(frame.Payload[:2])
+ if code < 1000 || code >= 5000 {
+ return fmt.Errorf("close frame status code %d out of range", code)
+ }
+ if reservedStatusCodes[code] {
+ return fmt.Errorf("close frame status code %d is reserved", code)
+ }
+
+ if len(frame.Payload) > 2 {
+ if !utf8.Valid(frame.Payload[2:]) {
+ return errors.New("close frame payload must be vaid UTF-8")
+ }
+ }
+ }
+
+ return nil
+}
+
+func acceptKey(clientKey string) string {
+ // Magic value comes from RFC 6455 section 1.3: Opening Handshake
+ // https://www.rfc-editor.org/rfc/rfc6455#section-1.3
+ h := sha1.New()
+ io.WriteString(h, clientKey+"258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
+ return base64.StdEncoding.EncodeToString(h.Sum(nil))
+}
diff --git a/httpbin/websocket/websocket_autobahn_test.go b/httpbin/websocket/websocket_autobahn_test.go
new file mode 100644
index 00000000..c1ae0c95
--- /dev/null
+++ b/httpbin/websocket/websocket_autobahn_test.go
@@ -0,0 +1,211 @@
+package websocket_test
+
+import (
+ "encoding/json"
+ "fmt"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "os/exec"
+ "path"
+ "path/filepath"
+ "runtime"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
+ "github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
+)
+
+const autobahnImage = "crossbario/autobahn-testsuite:0.8.2"
+
+var defaultIncludedTestCases = []string{
+ "*",
+}
+
+var defaultExcludedTestCases = []string{
+ // These cases all seem to rely on the server accepting fragmented text
+ // frames with invalid utf8 payloads, but the spec seems to indicate that
+ // every text fragment must be valid utf8 on its own.
+ "6.2.3",
+ "6.2.4",
+ "6.4.2",
+
+ // Compression extensions are not supported
+ "12.*",
+ "13.*",
+}
+
+func TestWebSocketServer(t *testing.T) {
+ t.Parallel()
+
+ if os.Getenv("AUTOBAHN_TESTS") == "" {
+ t.Skipf("set AUTOBAHN_TESTS=1 to run autobahn integration tests")
+ }
+
+ includedTestCases := defaultIncludedTestCases
+ excludedTestCases := defaultExcludedTestCases
+ if userTestCases := os.Getenv("AUTOBAHN_CASES"); userTestCases != "" {
+ t.Logf("using AUTOBAHN_CASES=%q", userTestCases)
+ includedTestCases = strings.Split(userTestCases, ",")
+ excludedTestCases = []string{}
+ }
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ws := websocket.New(w, r, websocket.Limits{
+ MaxFragmentSize: 1024 * 1024 * 16,
+ MaxMessageSize: 1024 * 1024 * 16,
+ })
+ if err := ws.Handshake(); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ ws.Serve(websocket.EchoHandler)
+ }))
+ defer srv.Close()
+
+ testDir := newTestDir(t)
+ t.Logf("test dir: %s", testDir)
+
+ targetURL := newAutobahnTargetURL(t, srv)
+ t.Logf("target url: %s", targetURL)
+
+ autobahnCfg := map[string]any{
+ "servers": []map[string]string{
+ {
+ "agent": "go-httpbin",
+ "url": targetURL,
+ },
+ },
+ "outdir": "/testdir/report",
+ "cases": includedTestCases,
+ "exclude-cases": excludedTestCases,
+ }
+
+ autobahnCfgFile, err := os.Create(path.Join(testDir, "autobahn.json"))
+ assert.NilError(t, err)
+ assert.NilError(t, json.NewEncoder(autobahnCfgFile).Encode(autobahnCfg))
+ autobahnCfgFile.Close()
+
+ pullCmd := exec.Command("docker", "pull", autobahnImage)
+ runCmd(t, pullCmd)
+
+ testCmd := exec.Command(
+ "docker",
+ "run",
+ "--net=host",
+ "--rm",
+ "-v", testDir+":/testdir:rw",
+ autobahnImage,
+ "wstest", "-m", "fuzzingclient", "--spec", "/testdir/autobahn.json",
+ )
+ runCmd(t, testCmd)
+
+ summary := loadSummary(t, testDir)
+ if len(summary) == 0 {
+ t.Fatalf("empty autobahn test summary; check autobahn logs for problems connecting to test server at %q", targetURL)
+ }
+
+ for _, results := range summary {
+ for caseName, result := range results {
+ result := result
+ t.Run("autobahn/"+caseName, func(t *testing.T) {
+ if result.Behavior == "FAILED" || result.BehaviorClose == "FAILED" {
+ report := loadReport(t, testDir, result.ReportFile)
+ t.Errorf("description: %s", report.Description)
+ t.Errorf("expectation: %s", report.Expectation)
+ t.Errorf("result: %s", report.Result)
+ t.Errorf("close: %s", report.ResultClose)
+ }
+ })
+ }
+ }
+
+ t.Logf("autobahn test report: %s", path.Join(testDir, "report/index.html"))
+ if os.Getenv("AUTOBAHN_OPEN_REPORT") != "" {
+ runCmd(t, exec.Command("open", path.Join(testDir, "report/index.html")))
+ }
+}
+
+// newAutobahnTargetURL returns the URL that the autobahn test suite should use
+// to connect to the given httptest server.
+//
+// On Macs, the docker engine is running inside an implicit VM, so even with
+// --net=host, we need to use the special hostname to escape the VM.
+//
+// See the Docker Desktop docs[1] for more information. This same special
+// hostname seems to work across Docker Desktop for Mac, OrbStack, and Colima.
+//
+// [1]: https://docs.docker.com/desktop/networking/#i-want-to-connect-from-a-container-to-a-service-on-the-host
+func newAutobahnTargetURL(t *testing.T, srv *httptest.Server) string {
+ t.Helper()
+ u, err := url.Parse(srv.URL)
+ assert.NilError(t, err)
+
+ var host string
+ switch runtime.GOOS {
+ case "darwin":
+ host = "host.docker.internal"
+ default:
+ host = "127.0.0.1"
+ }
+
+ return fmt.Sprintf("ws://%s:%s/websocket/echo", host, u.Port())
+}
+
+func runCmd(t *testing.T, cmd *exec.Cmd) {
+ t.Helper()
+ t.Logf("running command: %s", cmd.String())
+ cmd.Stdout = os.Stderr
+ cmd.Stderr = os.Stderr
+ assert.NilError(t, cmd.Run())
+}
+
+func newTestDir(t *testing.T) string {
+ t.Helper()
+
+ // package tests are run with the package as the working directory, but we
+ // want to store our integration test output in the repo root
+ testDir, err := filepath.Abs(path.Join(
+ "..", "..", ".integrationtests", fmt.Sprintf("autobahn-test-%d", time.Now().Unix()),
+ ))
+
+ assert.NilError(t, err)
+ assert.NilError(t, os.MkdirAll(testDir, 0o755))
+ return testDir
+}
+
+func loadSummary(t *testing.T, testDir string) autobahnReportSummary {
+ t.Helper()
+ f, err := os.Open(path.Join(testDir, "report", "index.json"))
+ assert.NilError(t, err)
+ defer f.Close()
+ var summary autobahnReportSummary
+ assert.NilError(t, json.NewDecoder(f).Decode(&summary))
+ return summary
+}
+
+func loadReport(t *testing.T, testDir string, reportFile string) autobahnReportResult {
+ t.Helper()
+ reportPath := path.Join(testDir, "report", reportFile)
+ t.Logf("report path: %s", reportPath)
+ f, err := os.Open(reportPath)
+ assert.NilError(t, err)
+ var report autobahnReportResult
+ assert.NilError(t, json.NewDecoder(f).Decode(&report))
+ return report
+}
+
+type autobahnReportSummary map[string]map[string]autobahnReportResult
+
+type autobahnReportResult struct {
+ Behavior string `json:"behavior"`
+ BehaviorClose string `json:"behaviorClose"`
+ Description string `json:"description"`
+ Expectation string `json:"expectation"`
+ ReportFile string `json:"reportfile"`
+ Result string `json:"result"`
+ ResultClose string `json:"resultClose"`
+}
diff --git a/httpbin/websocket/websocket_test.go b/httpbin/websocket/websocket_test.go
new file mode 100644
index 00000000..cb1084e0
--- /dev/null
+++ b/httpbin/websocket/websocket_test.go
@@ -0,0 +1,246 @@
+package websocket_test
+
+import (
+ "bufio"
+ "fmt"
+ "net"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/mccutchen/go-httpbin/v2/httpbin/websocket"
+ "github.com/mccutchen/go-httpbin/v2/internal/testing/assert"
+)
+
+func TestHandshake(t *testing.T) {
+ testCases := map[string]struct {
+ reqHeaders map[string]string
+ wantStatus int
+ wantRespHeaders map[string]string
+ }{
+ "valid handshake": {
+ reqHeaders: map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ },
+ wantRespHeaders: map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-Websocket-Accept": "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
+ },
+ wantStatus: http.StatusSwitchingProtocols,
+ },
+ "valid handshake, header values case insensitive": {
+ reqHeaders: map[string]string{
+ "Connection": "Upgrade",
+ "Upgrade": "WebSocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ },
+ wantRespHeaders: map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-Websocket-Accept": "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=",
+ },
+ wantStatus: http.StatusSwitchingProtocols,
+ },
+ "missing Connection header": {
+ reqHeaders: map[string]string{
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ },
+ wantStatus: http.StatusBadRequest,
+ },
+ "incorrect Connection header": {
+ reqHeaders: map[string]string{
+ "Connection": "close",
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ },
+ wantStatus: http.StatusBadRequest,
+ },
+ "missing Upgrade header": {
+ reqHeaders: map[string]string{
+ "Connection": "Upgrade",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ },
+ wantStatus: http.StatusBadRequest,
+ },
+ "incorrect Upgrade header": {
+ reqHeaders: map[string]string{
+ "Connection": "Upgrade",
+ "Upgrade": "http/2",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ },
+ wantStatus: http.StatusBadRequest,
+ },
+ "missing version": {
+ reqHeaders: map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ },
+ wantStatus: http.StatusBadRequest,
+ },
+ "incorrect version": {
+ reqHeaders: map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "12",
+ },
+ wantStatus: http.StatusBadRequest,
+ },
+ "missing Sec-WebSocket-Key": {
+ reqHeaders: map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Version": "13",
+ },
+ wantStatus: http.StatusBadRequest,
+ },
+ }
+ for name, tc := range testCases {
+ tc := tc
+ t.Run(name, func(t *testing.T) {
+ t.Parallel()
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ ws := websocket.New(w, r, websocket.Limits{})
+ if err := ws.Handshake(); err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+ ws.Serve(websocket.EchoHandler)
+ }))
+ defer srv.Close()
+
+ req, _ := http.NewRequest(http.MethodGet, srv.URL, nil)
+ for k, v := range tc.reqHeaders {
+ req.Header.Set(k, v)
+ }
+
+ resp, err := http.DefaultClient.Do(req)
+ assert.NilError(t, err)
+
+ assert.StatusCode(t, resp, tc.wantStatus)
+ for k, v := range tc.wantRespHeaders {
+ assert.Equal(t, resp.Header.Get(k), v, "incorrect value for %q response header", k)
+ }
+ })
+ }
+}
+
+func TestHandshakeOrder(t *testing.T) {
+ handshakeReq := httptest.NewRequest(http.MethodGet, "/websocket/echo", nil)
+ for k, v := range map[string]string{
+ "Connection": "upgrade",
+ "Upgrade": "websocket",
+ "Sec-WebSocket-Key": "dGhlIHNhbXBsZSBub25jZQ==",
+ "Sec-WebSocket-Version": "13",
+ } {
+ handshakeReq.Header.Set(k, v)
+ }
+
+ t.Run("double handshake", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ ws := websocket.New(w, handshakeReq, websocket.Limits{})
+
+ // first handshake succeeds
+ assert.NilError(t, ws.Handshake())
+ assert.Equal(t, w.Code, http.StatusSwitchingProtocols, "incorrect status code")
+
+ // second handshake fails
+ defer func() {
+ r := recover()
+ if r == nil {
+ t.Fatalf("expected to catch panic on double handshake")
+ }
+ assert.Equal(t, fmt.Sprint(r), "websocket: handshake already completed", "incorrect panic message")
+ }()
+ ws.Handshake()
+ })
+
+ t.Run("handshake not completed", func(t *testing.T) {
+ defer func() {
+ r := recover()
+ if r == nil {
+ t.Fatalf("expected to catch panic on Serve before Handshake")
+ }
+ assert.Equal(t, fmt.Sprint(r), "websocket: serve: handshake not completed", "incorrect panic message")
+ }()
+ w := httptest.NewRecorder()
+ websocket.New(w, handshakeReq, websocket.Limits{}).Serve(nil)
+ })
+
+ t.Run("http.Hijack not implemented", func(t *testing.T) {
+ // confirm that httptest.ResponseRecorder does not implmeent
+ // http.Hjijacker
+ var rw http.ResponseWriter = httptest.NewRecorder()
+ _, ok := rw.(http.Hijacker)
+ assert.Equal(t, ok, false, "expected httptest.ResponseRecorder not to implement http.Hijacker")
+
+ w := httptest.NewRecorder()
+ ws := websocket.New(w, handshakeReq, websocket.Limits{})
+
+ assert.NilError(t, ws.Handshake())
+ assert.Equal(t, w.Code, http.StatusSwitchingProtocols, "incorrect status code")
+
+ defer func() {
+ r := recover()
+ if r == nil {
+ t.Fatalf("expected to catch panic on when http.Hijack not implemented")
+ }
+ assert.Equal(t, fmt.Sprint(r), "websocket: serve: server does not support hijacking", "incorrect panic message")
+ }()
+ ws.Serve(nil)
+ })
+
+ t.Run("hijack failed", func(t *testing.T) {
+ w := &brokenHijackResponseWriter{}
+ ws := websocket.New(w, handshakeReq, websocket.Limits{})
+
+ assert.NilError(t, ws.Handshake())
+ assert.Equal(t, w.Code, http.StatusSwitchingProtocols, "incorrect status code")
+
+ defer func() {
+ r := recover()
+ if r == nil {
+ t.Fatalf("expected to catch panic on Serve before Handshake")
+ }
+ assert.Equal(t, fmt.Sprint(r), "websocket: serve: hijack failed: error hijacking connection", "incorrect panic message")
+ }()
+ ws.Serve(nil)
+ })
+}
+
+// brokenHijackResponseWriter implements just enough to satisfy the
+// http.ResponseWriter and http.Hijacker interfaces and get through the
+// handshake before failing to actually hijack the connection.
+type brokenHijackResponseWriter struct {
+ http.ResponseWriter
+ Code int
+}
+
+func (w *brokenHijackResponseWriter) WriteHeader(code int) {
+ w.Code = code
+}
+
+func (w *brokenHijackResponseWriter) Header() http.Header {
+ return http.Header{}
+}
+
+func (brokenHijackResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ return nil, nil, fmt.Errorf("error hijacking connection")
+}
+
+var (
+ _ http.ResponseWriter = &brokenHijackResponseWriter{}
+ _ http.Hijacker = &brokenHijackResponseWriter{}
+)