diff --git a/go.mod b/go.mod index 41d7730d39..d37a6ff09f 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b github.com/mr-tron/base58 v1.2.0 github.com/multiformats/go-base32 v0.1.0 - github.com/multiformats/go-multiaddr v0.13.0 + github.com/multiformats/go-multiaddr v0.14.0 github.com/multiformats/go-multiaddr-dns v0.4.0 github.com/multiformats/go-multiaddr-fmt v0.1.0 github.com/multiformats/go-multibase v0.2.0 diff --git a/go.sum b/go.sum index df6db73cff..7b0c780196 100644 --- a/go.sum +++ b/go.sum @@ -233,8 +233,8 @@ github.com/multiformats/go-base32 v0.1.0/go.mod h1:Kj3tFY6zNr+ABYMqeUNeGvkIC/UYg github.com/multiformats/go-base36 v0.2.0 h1:lFsAbNOGeKtuKozrtBsAkSVhv1p9D0/qedU9rQyccr0= github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a1UV0xHgWc0hkp4= github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= -github.com/multiformats/go-multiaddr v0.13.0 h1:BCBzs61E3AGHcYYTv8dqRH43ZfyrqM8RXVPT8t13tLQ= -github.com/multiformats/go-multiaddr v0.13.0/go.mod h1:sBXrNzucqkFJhvKOiwwLyqamGa/P5EIXNPLovyhQCII= +github.com/multiformats/go-multiaddr v0.14.0 h1:bfrHrJhrRuh/NXH5mCnemjpbGjzRw/b+tJFOD41g2tU= +github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4= github.com/multiformats/go-multiaddr-dns v0.4.0 h1:P76EJ3qzBXpUXZ3twdCDx/kvagMsNo0LMFXpyms/zgU= github.com/multiformats/go-multiaddr-dns v0.4.0/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E= diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..e353ba6526 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -31,6 +31,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "go.uber.org/mock/gomock" @@ -156,6 +157,21 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "Memory", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pmemory.NewTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/memory/1234")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, } func TestPing(t *testing.T) { diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go new file mode 100644 index 0000000000..515fb43625 --- /dev/null +++ b/p2p/transport/memory/conn.go @@ -0,0 +1,130 @@ +package memory + +import ( + "context" + "sync" + "sync/atomic" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +type conn struct { + id int64 + + transport *transport + scope network.ConnManagementScope + + localPeer peer.ID + localMultiaddr ma.Multiaddr + + remotePeerID peer.ID + remotePubKey ic.PubKey + remoteMultiaddr ma.Multiaddr + + mu sync.Mutex + + closed atomic.Bool + closeOnce sync.Once + + streamC chan *stream + streams map[int64]network.MuxedStream +} + +var _ tpt.CapableConn = &conn{} + +func newConnection( + t *transport, + s *stream, + localPeer peer.ID, + localMultiaddr ma.Multiaddr, + remotePubKey ic.PubKey, + remotePeer peer.ID, + remoteMultiaddr ma.Multiaddr, +) *conn { + c := &conn{ + id: connCounter.Add(1), + transport: t, + localPeer: localPeer, + localMultiaddr: localMultiaddr, + remotePubKey: remotePubKey, + remotePeerID: remotePeer, + remoteMultiaddr: remoteMultiaddr, + streamC: make(chan *stream, 1), + streams: make(map[int64]network.MuxedStream), + } + + c.addStream(s.id, s) + return c +} + +func (c *conn) Close() error { + c.closed.Store(true) + for _, s := range c.streams { + s.Close() + } + + return nil +} + +func (c *conn) IsClosed() bool { + return c.closed.Load() +} + +func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { + inStream, outStream := newStreamPair() + + c.streamC <- inStream + return outStream, nil +} + +func (c *conn) AcceptStream() (network.MuxedStream, error) { + in := <-c.streamC + id := streamCounter.Add(1) + c.addStream(id, in) + return in, nil +} + +func (c *conn) LocalPeer() peer.ID { return c.localPeer } + +// RemotePeer returns the peer ID of the remote peer. +func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } + +// RemotePublicKey returns the public pkey of the remote peer. +func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } + +// LocalMultiaddr returns the local Multiaddr associated +func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } + +// RemoteMultiaddr returns the remote Multiaddr associated +func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } + +func (c *conn) Transport() tpt.Transport { + return c.transport +} + +func (c *conn) Scope() network.ConnScope { + return c.scope +} + +// ConnState is the state of security connection. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{Transport: "memory"} +} + +func (c *conn) addStream(id int64, stream network.MuxedStream) { + c.mu.Lock() + defer c.mu.Unlock() + + c.streams[id] = stream +} + +func (c *conn) removeStream(id int64) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.streams, id) +} diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go new file mode 100644 index 0000000000..54417e2a8b --- /dev/null +++ b/p2p/transport/memory/listener.go @@ -0,0 +1,71 @@ +package memory + +import ( + "context" + "net" + "sync" + + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +const ( + listenerQueueSize = 16 +) + +type listener struct { + t *transport + ctx context.Context + cancel context.CancelFunc + laddr ma.Multiaddr + + mu sync.Mutex + connCh chan *conn + connections map[int64]*conn +} + +func (l *listener) Multiaddr() ma.Multiaddr { + return l.laddr +} + +func newListener(t *transport, laddr ma.Multiaddr) *listener { + ctx, cancel := context.WithCancel(context.Background()) + return &listener{ + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), + connections: make(map[int64]*conn), + } +} + +// Accept accepts new connections. +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case <-l.ctx.Done(): + return nil, tpt.ErrListenerClosed + case c, ok := <-l.connCh: + if !ok { + return nil, tpt.ErrListenerClosed + } + + l.mu.Lock() + defer l.mu.Unlock() + + c.transport = l.t + l.connections[c.id] = c + return c, nil + } +} + +// Close closes the listener. +func (l *listener) Close() error { + l.cancel() + return nil +} + +// Addr returns the address of this listener. +func (l *listener) Addr() net.Addr { + return nil +} diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go new file mode 100644 index 0000000000..66d8879f88 --- /dev/null +++ b/p2p/transport/memory/stream.go @@ -0,0 +1,132 @@ +package memory + +import ( + "errors" + "io" + "net" + "sync/atomic" + "time" + + "github.com/libp2p/go-libp2p/core/network" +) + +// stream implements network.Stream +type stream struct { + id int64 + + write chan byte + read chan byte + + reset chan struct{} + closeRead chan struct{} + closeWrite chan struct{} + closed atomic.Bool +} + +var ErrClosed = errors.New("stream closed") + +func newStreamPair() (*stream, *stream) { + ra, rb := make(chan byte, 4096), make(chan byte, 4096) + wa, wb := rb, ra + + in := newStream(rb, wb, network.DirInbound) + out := newStream(ra, wa, network.DirOutbound) + return in, out +} + +func newStream(r, w chan byte, _ network.Direction) *stream { + s := &stream{ + id: streamCounter.Add(1), + read: r, + write: w, + reset: make(chan struct{}, 1), + closeRead: make(chan struct{}, 1), + closeWrite: make(chan struct{}, 1), + } + + return s +} + +// How to handle errors with writes? +func (s *stream) Write(p []byte) (n int, err error) { + if s.closed.Load() { + return 0, ErrClosed + } + + for i := 0; i < len(p); i++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeWrite: + err = ErrClosed + return + case s.write <- p[i]: + n = i + default: + err = io.ErrClosedPipe + } + } + + return n + 1, err +} + +func (s *stream) Read(p []byte) (n int, err error) { + if s.closed.Load() { + return 0, ErrClosed + } + + for n = 0; n < len(p); n++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeRead: + err = ErrClosed + return + case b, ok := <-s.read: + if !ok { + err = io.EOF + return + } + p[n] = b + default: + err = io.EOF + return + } + } + + return +} + +func (s *stream) CloseWrite() error { + s.closeWrite <- struct{}{} + return nil +} + +func (s *stream) CloseRead() error { + s.closeRead <- struct{}{} + return nil +} + +func (s *stream) Close() error { + s.closed.Store(true) + return nil +} + +func (s *stream) Reset() error { + s.reset <- struct{}{} + return nil +} + +func (s *stream) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (s *stream) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go new file mode 100644 index 0000000000..cd5149c685 --- /dev/null +++ b/p2p/transport/memory/stream_test.go @@ -0,0 +1,48 @@ +package memory + +import ( + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStreamSimpleReadWriteClose(t *testing.T) { + clientStr, serverStr := newStreamPair() + + // send a foobar from the client + n, err := clientStr.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) + require.NoError(t, clientStr.CloseWrite()) + + // writing after closing should error + _, err = clientStr.Write([]byte("foobar")) + require.Error(t, err) + + // now read all the data on the server side + b, err := io.ReadAll(serverStr) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b) + // reading again should give another io.EOF + n, err = serverStr.Read(make([]byte, 10)) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + + // send something back + _, err = serverStr.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, serverStr.CloseWrite()) + + // and read it at the client + //require.False(t, clientDone.Load()) + b, err = io.ReadAll(clientStr) + require.NoError(t, err) + require.Equal(t, []byte("lorem ipsum"), b) + + // stream is only cleaned up on calling Close or Reset + clientStr.Close() + serverStr.Close() + // Need to call Close for cleanup. Otherwise the FIN_ACK is never read + require.NoError(t, serverStr.Close()) +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go new file mode 100644 index 0000000000..a13c737437 --- /dev/null +++ b/p2p/transport/memory/transport.go @@ -0,0 +1,211 @@ +package memory + +import ( + "context" + "errors" + "sync" + "sync/atomic" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/pnet" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) +) + +type hub struct { + mu sync.RWMutex + closeOnce sync.Once + pubKeys map[peer.ID]ic.PubKey + listeners map[string]*listener +} + +func (h *hub) addListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + h.listeners[addr] = l +} + +func (h *hub) removeListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.listeners, addr) +} + +func (h *hub) getListener(addr string) (*listener, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + l, ok := h.listeners[addr] + return l, ok +} + +func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { + h.mu.Lock() + defer h.mu.Unlock() + + h.pubKeys[p] = pk +} + +func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + pk, ok := h.pubKeys[p] + return pk, ok +} + +func (h *hub) close() { + h.closeOnce.Do(func() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, l := range h.listeners { + l.Close() + } + }) +} + +var memhub = &hub{ + listeners: make(map[string]*listener), + pubKeys: make(map[peer.ID]ic.PubKey), +} + +type transport struct { + psk pnet.PSK + rcmgr network.ResourceManager + localPeerID peer.ID + localPrivKey ic.PrivKey + localPubKey ic.PubKey + + mu sync.RWMutex + + connections map[int64]*conn +} + +func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + + id, err := peer.IDFromPrivateKey(privKey) + if err != nil { + return nil, err + } + + memhub.addPubKey(id, privKey.GetPublic()) + return &transport{ + psk: psk, + rcmgr: rcmgr, + localPeerID: id, + localPrivKey: privKey, + localPubKey: privKey.GetPublic(), + connections: make(map[int64]*conn), + }, nil +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + return nil, err + } + + c, err := t.dialWithScope(ctx, raddr, p, scope) + if err != nil { + return nil, err + } + + return c, nil +} + +func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + if err := scope.SetPeer(rpid); err != nil { + return nil, err + } + + // TODO: Check if there is an existing listener for this address + t.mu.RLock() + defer t.mu.RUnlock() + l, ok := memhub.getListener(raddr.String()) + if !ok { + return nil, errors.New("failed to get listener") + } + + remotePubKey, ok := memhub.getPubKey(rpid) + if !ok { + return nil, errors.New("failed to get remote public key") + } + + inStream, outStream := newStreamPair() + inConn := newConnection(t, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr) + outConn := newConnection(nil, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) + l.connCh <- outConn + + return inConn, nil +} + +func (t *transport) CanDial(addr ma.Multiaddr) bool { + return dialMatcher.Matches(addr) +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + // TODO: Check if we need to add scope via conn mngr + l := newListener(t, laddr) + + t.mu.Lock() + defer t.mu.Unlock() + + memhub.addListener(laddr.String(), l) + + return l, nil +} + +func (t *transport) Proxy() bool { + return false +} + +// Protocols returns the set of protocols handled by this transport. +func (t *transport) Protocols() []int { + return []int{ma.P_MEMORY} +} + +func (t *transport) String() string { + return "MemoryTransport" +} + +func (t *transport) Close() error { + // TODO: Go trough all listeners and close them + memhub.close() + + for _, c := range t.connections { + c.Close() + delete(t.connections, c.id) + } + + return nil +} + +func (t *transport) addConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + t.connections[c.id] = c +} + +func (t *transport) removeConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + delete(t.connections, c.id) +} diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go new file mode 100644 index 0000000000..f83f0d1280 --- /dev/null +++ b/p2p/transport/memory/transport_test.go @@ -0,0 +1,70 @@ +package memory + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "io" + "testing" + + ic "github.com/libp2p/go-libp2p/core/crypto" + tpt "github.com/libp2p/go-libp2p/core/transport" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func getTransport(t *testing.T) tpt.Transport { + t.Helper() + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + require.NoError(t, err) + tr, err := NewTransport(key, nil, nil) + require.NoError(t, err) + return tr +} + +func TestMemoryProtocol(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + protocols := tr.Protocols() + if len(protocols) > 1 { + t.Fatalf("expected at most one protocol, got %v", protocols) + } + + if protocols[0] != ma.P_MEMORY { + t.Fatalf("expected the supported protocol to be memory, got %d", protocols[0]) + } +} + +func TestCanDial(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + invalid := []string{ + "/ip4/127.0.0.1/udp/1234", + "/ip4/5.5.5.5/tcp/1234", + "/dns/google.com/udp/443/quic-v1", + "/ip4/127.0.0.1/udp/1234/quic", + } + valid := []string{ + "/memory/1234", + "/memory/1337123", + } + for _, s := range invalid { + invalidAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if tr.CanDial(invalidAddr) { + t.Errorf("didn't expect to be able to dial a non-memory address (%s)", invalidAddr) + } + } + for _, s := range valid { + validAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if !tr.CanDial(validAddr) { + t.Errorf("expected to be able to dial memory address (%s)", validAddr) + } + } +}