diff --git a/p2p/http/libp2phttp.go b/p2p/http/libp2phttp.go index d9cf0e54b6..d8158153ac 100644 --- a/p2p/http/libp2phttp.go +++ b/p2p/http/libp2phttp.go @@ -181,7 +181,9 @@ func (h *Host) httpTransportInit() { func (h *Host) serveMuxInit() { h.initializeServeMux.Do(func() { - h.ServeMux = http.NewServeMux() + if h.ServeMux == nil { + h.ServeMux = http.NewServeMux() + } }) } diff --git a/p2p/http/libp2phttp_test.go b/p2p/http/libp2phttp_test.go index 1114b5a9d9..8a95bed6d7 100644 --- a/p2p/http/libp2phttp_test.go +++ b/p2p/http/libp2phttp_test.go @@ -390,3 +390,32 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config { } return tlsConfig } + +func TestCustomServeMux(t *testing.T) { + serveMux := http.NewServeMux() + serveMux.Handle("/ping/", httpping.Ping{}) + + server := libp2phttp.Host{ + ListenAddrs: []ma.Multiaddr{ma.StringCast("/ip4/127.0.0.1/tcp/0/http")}, + ServeMux: serveMux, + InsecureAllowHTTP: true, + } + server.WellKnownHandler.AddProtocolMeta(httpping.PingProtocolID, libp2phttp.ProtocolMeta{Path: "/ping/"}) + go func() { + server.Serve() + }() + defer server.Close() + + addrs := server.Addrs() + require.Equal(t, len(addrs), 1) + var clientHttpHost libp2phttp.Host + rt, err := clientHttpHost.NewConstrainedRoundTripper(peer.AddrInfo{Addrs: addrs}, libp2phttp.PreferHTTPTransport) + require.NoError(t, err) + + client := &http.Client{Transport: rt} + body := [32]byte{} + req, _ := http.NewRequest(http.MethodPost, "/ping/", bytes.NewReader(body[:])) + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode) +}