From 39983c2a3f2845b6c2dc7983591f458ad23ce652 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Tue, 22 Oct 2024 21:05:46 +0200 Subject: [PATCH 1/6] Add memory transport --- p2p/transport/memory/conn.go | 76 ++++++++++++++++++ p2p/transport/memory/stream.go | 124 ++++++++++++++++++++++++++++++ p2p/transport/memory/transport.go | 7 ++ 3 files changed, 207 insertions(+) create mode 100644 p2p/transport/memory/conn.go create mode 100644 p2p/transport/memory/stream.go create mode 100644 p2p/transport/memory/transport.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go new file mode 100644 index 0000000000..2665f081bb --- /dev/null +++ b/p2p/transport/memory/conn.go @@ -0,0 +1,76 @@ +package memory + +import ( + "context" + "sync" + + ic "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +type conn struct { + transport *transport + scope network.ConnManagementScope + + localPeer peer.ID + localMultiaddr ma.Multiaddr + + remotePeerID peer.ID + remotePubKey ic.PubKey + remoteMultiaddr ma.Multiaddr + + closed bool + closeOnce sync.Once +} + +var _ tpt.CapableConn = &conn{} + +func (c *conn) Close() error { + c.closeOnce.Do(func() { + c.closed = true + c.transport.removeConn(c) + }) + + return nil +} + +func (c *conn) IsClosed() bool { + return c.closed +} + +func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { + return newStream(), nil +} + +func (c *conn) AcceptStream() (network.MuxedStream, error) { + return nil, nil +} + +func (c *conn) LocalPeer() peer.ID { return c.localPeer } + +// RemotePeer returns the peer ID of the remote peer. +func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } + +// RemotePublicKey returns the public key of the remote peer. +func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } + +// LocalMultiaddr returns the local Multiaddr associated +func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } + +// RemoteMultiaddr returns the remote Multiaddr associated +func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } + +func (c *conn) Transport() tpt.Transport { + // TODO: return c.transport + return nil +} + +func (c *conn) Scope() network.ConnScope { return c.scope } + +// ConnState is the state of security connection. +func (c *conn) ConnState() network.ConnectionState { + return network.ConnectionState{Transport: "memory"} +} diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go new file mode 100644 index 0000000000..6b8555dcba --- /dev/null +++ b/p2p/transport/memory/stream.go @@ -0,0 +1,124 @@ +package memory + +import ( + "errors" + "sync" + "time" + + "github.com/libp2p/go-libp2p/core/network" +) + +type stream struct { + inC <-chan []byte + outC chan<- []byte + + readCloseC chan struct{} + writeCloseC chan struct{} + + mu sync.Mutex + closed bool + + deadline time.Time + readDeadline time.Time + writeDeadline time.Time +} + +func newStream() *stream { + return &stream{ + inC: make(<-chan []byte), + outC: make(chan<- []byte), + readCloseC: make(chan struct{}), + writeCloseC: make(chan struct{}), + } +} + +func (s *stream) Read(b []byte) (n int, err error) { + if s.closed { + return 0, network.ErrReset + } + + select { + case <-s.readCloseC: + err = network.ErrReset + case r, ok := <-s.inC: + if !ok { + err = network.ErrReset + } else { + n = copy(b, r) + } + } + + return n, err +} + +func (s *stream) Write(b []byte) (n int, err error) { + select { + case <-s.writeCloseC: + err = network.ErrReset + case s.outC <- b: + n = len(b) + default: + err = network.ErrReset + } + + return n, err +} + +func (s *stream) Reset() error { + s.CloseWrite() + s.CloseRead() + return nil +} + +func (s *stream) Close() error { + s.CloseRead() + + s.mu.Lock() + s.closed = true + s.mu.Unlock() + + return nil +} + +func (s *stream) CloseRead() error { + select { + case s.readCloseC <- struct{}{}: + default: + return errors.New("failed to close stream read") + } + + return nil +} + +func (s *stream) CloseWrite() error { + select { + case s.writeCloseC <- struct{}{}: + default: + return errors.New("failed to close stream write") + } + + return nil +} + +func (s *stream) SetDeadline(deadline time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.deadline = deadline + return nil +} + +func (s *stream) SetReadDeadline(readDeadline time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.readDeadline = readDeadline + return nil +} +func (s *stream) SetWriteDeadline(writeDeadline time.Time) error { + s.mu.Lock() + defer s.mu.Unlock() + + s.writeDeadline = writeDeadline + return nil +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go new file mode 100644 index 0000000000..d5850d7fb0 --- /dev/null +++ b/p2p/transport/memory/transport.go @@ -0,0 +1,7 @@ +package memory + +type transport struct { +} + +func (t *transport) removeConn(c *conn) { +} From 37899d3c0d9cd1fdeb54b6e955deb74a833757c7 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Wed, 23 Oct 2024 22:46:00 +0200 Subject: [PATCH 2/6] Daily commit --- p2p/transport/memory/conn.go | 65 ++++++++++++++++++--- p2p/transport/memory/listener.go | 62 +++++++++++++++++++++ p2p/transport/memory/stream.go | 63 ++++++++++----------- p2p/transport/memory/transport.go | 93 +++++++++++++++++++++++++++++++ 4 files changed, 240 insertions(+), 43 deletions(-) create mode 100644 p2p/transport/memory/listener.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index 2665f081bb..6dac7f87a1 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -3,6 +3,7 @@ package memory import ( "context" "sync" + "sync/atomic" ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" @@ -12,6 +13,8 @@ import ( ) type conn struct { + id int32 + transport *transport scope network.ConnManagementScope @@ -22,15 +25,35 @@ type conn struct { remotePubKey ic.PubKey remoteMultiaddr ma.Multiaddr - closed bool + isClosed atomic.Bool closeOnce sync.Once + + mu sync.Mutex + + streamC chan *stream + + nextStreamID atomic.Int32 + streams map[int32]network.MuxedStream } var _ tpt.CapableConn = &conn{} +func newConnection(id int32, s *stream) *conn { + c := &conn{ + id: id, + streamC: make(chan *stream, 1), + streams: make(map[int32]network.MuxedStream), + } + + streamID := c.nextStreamID.Add(1) + c.addStream(streamID, s) + + return c +} + func (c *conn) Close() error { c.closeOnce.Do(func() { - c.closed = true + c.isClosed.Store(true) c.transport.removeConn(c) }) @@ -38,15 +61,26 @@ func (c *conn) Close() error { } func (c *conn) IsClosed() bool { - return c.closed + return c.isClosed.Load() } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - return newStream(), nil + id := c.nextStreamID.Add(1) + ra := make(chan []byte) + wa := make(chan []byte) + + return newStream(id, ra, wa), nil } func (c *conn) AcceptStream() (network.MuxedStream, error) { - return nil, nil + select { + case in := <-c.streamC: + id := c.nextStreamID.Add(1) + s := newStream(id, in.outC, in.inC) + c.addStream(id, s) + + return s, nil + } } func (c *conn) LocalPeer() peer.ID { return c.localPeer } @@ -64,13 +98,28 @@ func (c *conn) LocalMultiaddr() ma.Multiaddr { return c.localMultiaddr } func (c *conn) RemoteMultiaddr() ma.Multiaddr { return c.remoteMultiaddr } func (c *conn) Transport() tpt.Transport { - // TODO: return c.transport - return nil + return c.transport } -func (c *conn) Scope() network.ConnScope { return c.scope } +func (c *conn) Scope() network.ConnScope { + return c.scope +} // ConnState is the state of security connection. func (c *conn) ConnState() network.ConnectionState { return network.ConnectionState{Transport: "memory"} } + +func (c *conn) addStream(id int32, stream network.MuxedStream) { + c.mu.Lock() + defer c.mu.Unlock() + + c.streams[id] = stream +} + +func (c *conn) removeStream(id int32) { + c.mu.Lock() + defer c.mu.Unlock() + + delete(c.streams, id) +} diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go new file mode 100644 index 0000000000..a53f317815 --- /dev/null +++ b/p2p/transport/memory/listener.go @@ -0,0 +1,62 @@ +package memory + +import ( + "context" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + "net" + "sync" + "sync/atomic" +) + +type listener struct { + ctx context.Context + cancel context.CancelFunc + laddr ma.Multiaddr + + mu sync.Mutex + connID atomic.Int32 + streamCh chan *stream + connections map[int32]*conn +} + +func (l *listener) Multiaddr() ma.Multiaddr { + return l.laddr +} + +func newListener(laddr ma.Multiaddr, streamCh chan *stream) tpt.Listener { + ctx, cancel := context.WithCancel(context.Background()) + return &listener{ + ctx: ctx, + cancel: cancel, + laddr: laddr, + streamCh: streamCh, + } +} + +// Accept accepts new connections. +func (l *listener) Accept() (tpt.CapableConn, error) { + select { + case s := <-l.streamCh: + l.mu.Lock() + defer l.mu.Unlock() + + id := l.connID.Add(1) + c := newConnection(id, s) + l.connections[id] = c + return nil, nil + case <-l.ctx.Done(): + return nil, l.ctx.Err() + } +} + +// Close closes the listener. +func (l *listener) Close() error { + l.cancel() + return nil +} + +// Addr returns the address of this listener. +func (l *listener) Addr() net.Addr { + return nil +} diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 6b8555dcba..e816daf952 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -2,38 +2,36 @@ package memory import ( "errors" - "sync" + "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" ) type stream struct { - inC <-chan []byte - outC chan<- []byte + id int32 + + inC chan []byte + outC chan []byte readCloseC chan struct{} writeCloseC chan struct{} - mu sync.Mutex - closed bool - - deadline time.Time - readDeadline time.Time - writeDeadline time.Time + closed atomic.Bool } -func newStream() *stream { +func newStream(id int32, in, out chan []byte) *stream { return &stream{ - inC: make(<-chan []byte), - outC: make(chan<- []byte), + id: id, + inC: in, + outC: out, readCloseC: make(chan struct{}), writeCloseC: make(chan struct{}), } } func (s *stream) Read(b []byte) (n int, err error) { - if s.closed { + if s.closed.Load() { return 0, network.ErrReset } @@ -52,6 +50,10 @@ func (s *stream) Read(b []byte) (n int, err error) { } func (s *stream) Write(b []byte) (n int, err error) { + if s.closed.Load() { + return 0, network.ErrReset + } + select { case <-s.writeCloseC: err = network.ErrReset @@ -65,18 +67,21 @@ func (s *stream) Write(b []byte) (n int, err error) { } func (s *stream) Reset() error { - s.CloseWrite() - s.CloseRead() + if err := s.CloseWrite(); err != nil { + return err + } + if err := s.CloseRead(); err != nil { + return err + } return nil } func (s *stream) Close() error { - s.CloseRead() - - s.mu.Lock() - s.closed = true - s.mu.Unlock() + if err := s.CloseRead(); err != nil { + return err + } + s.closed.Store(true) return nil } @@ -100,25 +105,13 @@ func (s *stream) CloseWrite() error { return nil } -func (s *stream) SetDeadline(deadline time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.deadline = deadline +func (s *stream) SetDeadline(_ time.Time) error { return nil } -func (s *stream) SetReadDeadline(readDeadline time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.readDeadline = readDeadline +func (s *stream) SetReadDeadline(_ time.Time) error { return nil } -func (s *stream) SetWriteDeadline(writeDeadline time.Time) error { - s.mu.Lock() - defer s.mu.Unlock() - - s.writeDeadline = writeDeadline +func (s *stream) SetWriteDeadline(_ time.Time) error { return nil } diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index d5850d7fb0..e7f0f27e02 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -1,7 +1,100 @@ package memory +import ( + "context" + "github.com/libp2p/go-libp2p/core/network" + "github.com/libp2p/go-libp2p/core/peer" + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" + "sync" + "sync/atomic" +) + type transport struct { + rcmgr network.ResourceManager + + mu sync.RWMutex + + connID atomic.Int32 + listeners map[ma.Multiaddr]*listener + connections map[int32]*conn +} + +func NewTransport() *transport { + return &transport{ + connections: make(map[int32]*conn), + } +} + +func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { + scope, err := t.rcmgr.OpenConnection(network.DirOutbound, false, raddr) + if err != nil { + return nil, err + } + + c, err := t.dialWithScope(ctx, raddr, p, scope) + if err != nil { + return nil, err + } + + return c, nil +} + +func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + if err := scope.SetPeer(p); err != nil { + return nil, err + } + + // TODO: Check if there is an existing listener for this address + t.mu.RLock() + defer t.mu.RUnlock() + l := t.listeners[raddr] + + in := make(chan []byte) + out := make(chan []byte) + s := newStream(0, in, out) + l.streamCh <- s + + return newConnection(0, s), nil +} + +func (t *transport) CanDial(addr ma.Multiaddr) bool { + return true +} + +func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { + // TODO: Figure out correct channel type + return newListener(laddr, nil), nil +} + +func (t *transport) Proxy() bool { + return false +} + +// Protocols returns the set of protocols handled by this transport. +func (t *transport) Protocols() []int { + return []int{777} +} + +func (t *transport) String() string { + return "MemoryTransport" +} + +func (t *transport) Close() error { + // TODO: Go trough all listeners and close them + return nil +} + +func (t *transport) addConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + t.connections[c.id] = c } func (t *transport) removeConn(c *conn) { + t.mu.Lock() + defer t.mu.Unlock() + + delete(t.connections, c.id) } From 0df38f9190159806c237412c955d767f68a7d1d8 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Thu, 24 Oct 2024 22:03:54 +0200 Subject: [PATCH 3/6] Daily commit --- p2p/test/transport/transport_test.go | 16 ++++++++ p2p/transport/memory/conn.go | 12 +++--- p2p/transport/memory/listener.go | 24 ++++++----- p2p/transport/memory/stream.go | 39 +++++++----------- p2p/transport/memory/stream_test.go | 55 ++++++++++++++++++++++++++ p2p/transport/memory/transport.go | 59 ++++++++++++++++++++++------ 6 files changed, 151 insertions(+), 54 deletions(-) create mode 100644 p2p/transport/memory/stream_test.go diff --git a/p2p/test/transport/transport_test.go b/p2p/test/transport/transport_test.go index 7cfab5f3ca..e353ba6526 100644 --- a/p2p/test/transport/transport_test.go +++ b/p2p/test/transport/transport_test.go @@ -31,6 +31,7 @@ import ( "github.com/libp2p/go-libp2p/p2p/protocol/ping" "github.com/libp2p/go-libp2p/p2p/security/noise" tls "github.com/libp2p/go-libp2p/p2p/security/tls" + libp2pmemory "github.com/libp2p/go-libp2p/p2p/transport/memory" libp2pwebrtc "github.com/libp2p/go-libp2p/p2p/transport/webrtc" "go.uber.org/mock/gomock" @@ -156,6 +157,21 @@ var transportsToTest = []TransportTestCase{ return h }, }, + { + Name: "Memory", + HostGenerator: func(t *testing.T, opts TransportTestCaseOpts) host.Host { + libp2pOpts := transformOpts(opts) + libp2pOpts = append(libp2pOpts, libp2p.Transport(libp2pmemory.NewTransport)) + if opts.NoListen { + libp2pOpts = append(libp2pOpts, libp2p.NoListenAddrs) + } else { + libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings("/memory/1234")) + } + h, err := libp2p.New(libp2pOpts...) + require.NoError(t, err) + return h + }, + }, } func TestPing(t *testing.T) { diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index 6dac7f87a1..b01f05ed1a 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -2,6 +2,7 @@ package memory import ( "context" + "io" "sync" "sync/atomic" @@ -66,8 +67,8 @@ func (c *conn) IsClosed() bool { func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { id := c.nextStreamID.Add(1) - ra := make(chan []byte) - wa := make(chan []byte) + // TODO: Figure out how to exchange the pipes between the two streams + ra, wa := io.Pipe() return newStream(id, ra, wa), nil } @@ -76,10 +77,9 @@ func (c *conn) AcceptStream() (network.MuxedStream, error) { select { case in := <-c.streamC: id := c.nextStreamID.Add(1) - s := newStream(id, in.outC, in.inC) - c.addStream(id, s) + c.addStream(id, in) - return s, nil + return in, nil } } @@ -88,7 +88,7 @@ func (c *conn) LocalPeer() peer.ID { return c.localPeer } // RemotePeer returns the peer ID of the remote peer. func (c *conn) RemotePeer() peer.ID { return c.remotePeerID } -// RemotePublicKey returns the public key of the remote peer. +// RemotePublicKey returns the public pkey of the remote peer. func (c *conn) RemotePublicKey() ic.PubKey { return c.remotePubKey } // LocalMultiaddr returns the local Multiaddr associated diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index a53f317815..8041e02aae 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -6,17 +6,16 @@ import ( ma "github.com/multiformats/go-multiaddr" "net" "sync" - "sync/atomic" ) type listener struct { + t *transport ctx context.Context cancel context.CancelFunc laddr ma.Multiaddr mu sync.Mutex - connID atomic.Int32 - streamCh chan *stream + connCh chan *conn connections map[int32]*conn } @@ -24,27 +23,26 @@ func (l *listener) Multiaddr() ma.Multiaddr { return l.laddr } -func newListener(laddr ma.Multiaddr, streamCh chan *stream) tpt.Listener { +func newListener(t *transport, laddr ma.Multiaddr) *listener { ctx, cancel := context.WithCancel(context.Background()) return &listener{ - ctx: ctx, - cancel: cancel, - laddr: laddr, - streamCh: streamCh, + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), } } // Accept accepts new connections. func (l *listener) Accept() (tpt.CapableConn, error) { select { - case s := <-l.streamCh: + case c := <-l.connCh: l.mu.Lock() defer l.mu.Unlock() - id := l.connID.Add(1) - c := newConnection(id, s) - l.connections[id] = c - return nil, nil + l.connections[c.id] = c + return c, nil case <-l.ctx.Done(): return nil, l.ctx.Err() } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index e816daf952..4e425ee5af 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -2,6 +2,7 @@ package memory import ( "errors" + "io" "sync/atomic" "time" @@ -11,8 +12,8 @@ import ( type stream struct { id int32 - inC chan []byte - outC chan []byte + r *io.PipeReader + w *io.PipeWriter readCloseC chan struct{} writeCloseC chan struct{} @@ -20,50 +21,40 @@ type stream struct { closed atomic.Bool } -func newStream(id int32, in, out chan []byte) *stream { +func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { return &stream{ id: id, - inC: in, - outC: out, - readCloseC: make(chan struct{}), - writeCloseC: make(chan struct{}), + r: r, + w: w, + readCloseC: make(chan struct{}, 1), + writeCloseC: make(chan struct{}, 1), } } -func (s *stream) Read(b []byte) (n int, err error) { +func (s *stream) Read(b []byte) (int, error) { if s.closed.Load() { return 0, network.ErrReset } select { case <-s.readCloseC: - err = network.ErrReset - case r, ok := <-s.inC: - if !ok { - err = network.ErrReset - } else { - n = copy(b, r) - } + return 0, network.ErrReset + default: + return s.r.Read(b) } - - return n, err } -func (s *stream) Write(b []byte) (n int, err error) { +func (s *stream) Write(b []byte) (int, error) { if s.closed.Load() { return 0, network.ErrReset } select { case <-s.writeCloseC: - err = network.ErrReset - case s.outC <- b: - n = len(b) + return 0, network.ErrReset default: - err = network.ErrReset + return s.w.Write(b) } - - return n, err } func (s *stream) Reset() error { diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go new file mode 100644 index 0000000000..844000cd9d --- /dev/null +++ b/p2p/transport/memory/stream_test.go @@ -0,0 +1,55 @@ +package memory + +import ( + "github.com/stretchr/testify/require" + "io" + "testing" +) + +func TestStreamSimpleReadWriteClose(t *testing.T) { + //client, server := getDetachedDataChannels(t) + ra, wb := io.Pipe() + rb, wa := io.Pipe() + + clientStr := newStream(0, ra, wa) + serverStr := newStream(0, rb, wb) + + // send a foobar from the client + n, err := clientStr.Write([]byte("foobar")) + require.NoError(t, err) + require.Equal(t, 6, n) + require.NoError(t, clientStr.CloseWrite()) + // writing after closing should error + _, err = clientStr.Write([]byte("foobar")) + require.Error(t, err) + //require.False(t, clientDone.Load()) + + // now read all the data on the server side + b, err := io.ReadAll(serverStr) + require.NoError(t, err) + require.Equal(t, []byte("foobar"), b) + // reading again should give another io.EOF + n, err = serverStr.Read(make([]byte, 10)) + require.Zero(t, n) + require.ErrorIs(t, err, io.EOF) + //require.False(t, serverDone.Load()) + + // send something back + _, err = serverStr.Write([]byte("lorem ipsum")) + require.NoError(t, err) + require.NoError(t, serverStr.CloseWrite()) + + // and read it at the client + //require.False(t, clientDone.Load()) + b, err = io.ReadAll(clientStr) + require.NoError(t, err) + require.Equal(t, []byte("lorem ipsum"), b) + + // stream is only cleaned up on calling Close or Reset + clientStr.Close() + serverStr.Close() + //require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond) + // Need to call Close for cleanup. Otherwise the FIN_ACK is never read + require.NoError(t, serverStr.Close()) + //require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond) +} diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index e7f0f27e02..02eb1d24ee 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -2,28 +2,52 @@ package memory import ( "context" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + "io" "sync" "sync/atomic" ) +const ( + listenerQueueSize = 16 +) + type transport struct { + pkey ic.PrivKey + pid peer.ID + psk pnet.PSK rcmgr network.ResourceManager mu sync.RWMutex connID atomic.Int32 - listeners map[ma.Multiaddr]*listener + listeners map[string]*listener connections map[int32]*conn } -func NewTransport() *transport { +func NewTransport(key ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { + if rcmgr == nil { + rcmgr = &network.NullResourceManager{} + } + + id, err := peer.IDFromPrivateKey(key) + if err != nil { + return nil, err + } + return &transport{ + rcmgr: rcmgr, + pid: id, + pkey: key, + psk: psk, + listeners: make(map[string]*listener), connections: make(map[int32]*conn), - } + }, nil } func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tpt.CapableConn, error) { @@ -48,14 +72,16 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee // TODO: Check if there is an existing listener for this address t.mu.RLock() defer t.mu.RUnlock() - l := t.listeners[raddr] + l := t.listeners[raddr.String()] - in := make(chan []byte) - out := make(chan []byte) - s := newStream(0, in, out) - l.streamCh <- s + ra, wb := io.Pipe() + rb, wa := io.Pipe() + in, out := newStream(0, ra, wb), newStream(0, rb, wa) + inId, outId := t.connID.Add(1), t.connID.Add(1) - return newConnection(0, s), nil + l.connCh <- newConnection(inId, in) + + return newConnection(outId, out), nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { @@ -63,8 +89,15 @@ func (t *transport) CanDial(addr ma.Multiaddr) bool { } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { - // TODO: Figure out correct channel type - return newListener(laddr, nil), nil + // TODO: Check if we need to add scope via conn mngr + l := newListener(t, laddr) + + t.mu.Lock() + defer t.mu.Unlock() + + t.listeners[laddr.String()] = l + + return l, nil } func (t *transport) Proxy() bool { @@ -82,6 +115,10 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them + for _, l := range t.listeners { + l.Close() + } + return nil } From 67da1924dd80f832e39850cf785701f3182ad6aa Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Wed, 6 Nov 2024 15:54:05 +0100 Subject: [PATCH 4/6] Upgrade go-multiaddr --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 41d7730d39..d37a6ff09f 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( github.com/mikioh/tcpinfo v0.0.0-20190314235526-30a79bb1804b github.com/mr-tron/base58 v1.2.0 github.com/multiformats/go-base32 v0.1.0 - github.com/multiformats/go-multiaddr v0.13.0 + github.com/multiformats/go-multiaddr v0.14.0 github.com/multiformats/go-multiaddr-dns v0.4.0 github.com/multiformats/go-multiaddr-fmt v0.1.0 github.com/multiformats/go-multibase v0.2.0 diff --git a/go.sum b/go.sum index df6db73cff..7b0c780196 100644 --- a/go.sum +++ b/go.sum @@ -233,8 +233,8 @@ github.com/multiformats/go-base32 v0.1.0/go.mod h1:Kj3tFY6zNr+ABYMqeUNeGvkIC/UYg github.com/multiformats/go-base36 v0.2.0 h1:lFsAbNOGeKtuKozrtBsAkSVhv1p9D0/qedU9rQyccr0= github.com/multiformats/go-base36 v0.2.0/go.mod h1:qvnKE++v+2MWCfePClUEjE78Z7P2a1UV0xHgWc0hkp4= github.com/multiformats/go-multiaddr v0.1.1/go.mod h1:aMKBKNEYmzmDmxfX88/vz+J5IU55txyt0p4aiWVohjo= -github.com/multiformats/go-multiaddr v0.13.0 h1:BCBzs61E3AGHcYYTv8dqRH43ZfyrqM8RXVPT8t13tLQ= -github.com/multiformats/go-multiaddr v0.13.0/go.mod h1:sBXrNzucqkFJhvKOiwwLyqamGa/P5EIXNPLovyhQCII= +github.com/multiformats/go-multiaddr v0.14.0 h1:bfrHrJhrRuh/NXH5mCnemjpbGjzRw/b+tJFOD41g2tU= +github.com/multiformats/go-multiaddr v0.14.0/go.mod h1:6EkVAxtznq2yC3QT5CM1UTAwG0GTP3EWAIcjHuzQ+r4= github.com/multiformats/go-multiaddr-dns v0.4.0 h1:P76EJ3qzBXpUXZ3twdCDx/kvagMsNo0LMFXpyms/zgU= github.com/multiformats/go-multiaddr-dns v0.4.0/go.mod h1:7hfthtB4E4pQwirrz+J0CcDUfbWzTqEzVyYKKIKpgkc= github.com/multiformats/go-multiaddr-fmt v0.1.0 h1:WLEFClPycPkp4fnIzoFoV9FVd49/eQsuaL3/CWe167E= From e2a5865925445b58964c86a225b11bc01f3e85a9 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Fri, 8 Nov 2024 09:07:55 +0100 Subject: [PATCH 5/6] Daily commit --- p2p/transport/memory/conn.go | 42 +++++---- p2p/transport/memory/listener.go | 30 ++++--- p2p/transport/memory/stream.go | 56 ++++++------ p2p/transport/memory/stream_test.go | 5 +- p2p/transport/memory/transport.go | 131 +++++++++++++++++++++------- 5 files changed, 175 insertions(+), 89 deletions(-) diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index b01f05ed1a..d864e93316 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -39,11 +39,24 @@ type conn struct { var _ tpt.CapableConn = &conn{} -func newConnection(id int32, s *stream) *conn { +func newConnection( + id int32, + s *stream, + localPeer peer.ID, + localMultiaddr ma.Multiaddr, + remotePubKey ic.PubKey, + remotePeer peer.ID, + remoteMultiaddr ma.Multiaddr, +) *conn { c := &conn{ - id: id, - streamC: make(chan *stream, 1), - streams: make(map[int32]network.MuxedStream), + id: id, + localPeer: localPeer, + localMultiaddr: localMultiaddr, + remotePubKey: remotePubKey, + remotePeerID: remotePeer, + remoteMultiaddr: remoteMultiaddr, + streamC: make(chan *stream, 1), + streams: make(map[int32]network.MuxedStream), } streamID := c.nextStreamID.Add(1) @@ -66,21 +79,20 @@ func (c *conn) IsClosed() bool { } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - id := c.nextStreamID.Add(1) - // TODO: Figure out how to exchange the pipes between the two streams - ra, wa := io.Pipe() + ra, wb := io.Pipe() + rb, wa := io.Pipe() + inConnId, outConnId := c.nextStreamID.Add(1), c.nextStreamID.Add(1) + inStream, outStream := newStream(inConnId, ra, wb), newStream(outConnId, rb, wa) - return newStream(id, ra, wa), nil + c.streamC <- inStream + return outStream, nil } func (c *conn) AcceptStream() (network.MuxedStream, error) { - select { - case in := <-c.streamC: - id := c.nextStreamID.Add(1) - c.addStream(id, in) - - return in, nil - } + in := <-c.streamC + id := c.nextStreamID.Add(1) + c.addStream(id, in) + return in, nil } func (c *conn) LocalPeer() peer.ID { return c.localPeer } diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index 8041e02aae..39e8acfb29 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -2,10 +2,15 @@ package memory import ( "context" - tpt "github.com/libp2p/go-libp2p/core/transport" - ma "github.com/multiformats/go-multiaddr" "net" "sync" + + tpt "github.com/libp2p/go-libp2p/core/transport" + ma "github.com/multiformats/go-multiaddr" +) + +const ( + listenerQueueSize = 16 ) type listener struct { @@ -26,25 +31,30 @@ func (l *listener) Multiaddr() ma.Multiaddr { func newListener(t *transport, laddr ma.Multiaddr) *listener { ctx, cancel := context.WithCancel(context.Background()) return &listener{ - t: t, - ctx: ctx, - cancel: cancel, - laddr: laddr, - connCh: make(chan *conn, listenerQueueSize), + t: t, + ctx: ctx, + cancel: cancel, + laddr: laddr, + connCh: make(chan *conn, listenerQueueSize), + connections: make(map[int32]*conn), } } // Accept accepts new connections. func (l *listener) Accept() (tpt.CapableConn, error) { select { - case c := <-l.connCh: + case <-l.ctx.Done(): + return nil, tpt.ErrListenerClosed + case c, ok := <-l.connCh: + if !ok { + return nil, tpt.ErrListenerClosed + } + l.mu.Lock() defer l.mu.Unlock() l.connections[c.id] = c return c, nil - case <-l.ctx.Done(): - return nil, l.ctx.Err() } } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 4e425ee5af..101ae516da 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -1,7 +1,6 @@ package memory import ( - "errors" "io" "sync/atomic" "time" @@ -12,8 +11,9 @@ import ( type stream struct { id int32 - r *io.PipeReader - w *io.PipeWriter + r *io.PipeReader + w *io.PipeWriter + writeC chan []byte readCloseC chan struct{} writeCloseC chan struct{} @@ -22,26 +22,33 @@ type stream struct { } func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { - return &stream{ + s := &stream{ id: id, r: r, w: w, + writeC: make(chan []byte, 1), readCloseC: make(chan struct{}, 1), writeCloseC: make(chan struct{}, 1), } + + go func() { + for { + select { + case b := <-s.writeC: + if _, err := w.Write(b); err != nil { + return + } + case <-s.writeCloseC: + return + } + } + }() + + return s } func (s *stream) Read(b []byte) (int, error) { - if s.closed.Load() { - return 0, network.ErrReset - } - - select { - case <-s.readCloseC: - return 0, network.ErrReset - default: - return s.r.Read(b) - } + return s.r.Read(b) } func (s *stream) Write(b []byte) (int, error) { @@ -52,8 +59,8 @@ func (s *stream) Write(b []byte) (int, error) { select { case <-s.writeCloseC: return 0, network.ErrReset - default: - return s.w.Write(b) + case s.writeC <- b: + return len(b), nil } } @@ -68,31 +75,22 @@ func (s *stream) Reset() error { } func (s *stream) Close() error { - if err := s.CloseRead(); err != nil { - return err - } - - s.closed.Store(true) + s.CloseRead() + s.CloseWrite() return nil } func (s *stream) CloseRead() error { - select { - case s.readCloseC <- struct{}{}: - default: - return errors.New("failed to close stream read") - } - - return nil + return s.r.CloseWithError(network.ErrReset) } func (s *stream) CloseWrite() error { select { case s.writeCloseC <- struct{}{}: default: - return errors.New("failed to close stream write") } + s.closed.Store(true) return nil } diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 844000cd9d..33c3cbdc64 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -1,9 +1,10 @@ package memory import ( - "github.com/stretchr/testify/require" "io" "testing" + + "github.com/stretchr/testify/require" ) func TestStreamSimpleReadWriteClose(t *testing.T) { @@ -12,7 +13,7 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { rb, wa := io.Pipe() clientStr := newStream(0, ra, wa) - serverStr := newStream(0, rb, wb) + serverStr := newStream(1, rb, wb) // send a foobar from the client n, err := clientStr.Write([]byte("foobar")) diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index 02eb1d24ee..5016e3a7dd 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -2,51 +2,110 @@ package memory import ( "context" + "errors" + "io" + "sync" + "sync/atomic" + ic "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" - "io" - "sync" - "sync/atomic" ) -const ( - listenerQueueSize = 16 -) +type hub struct { + mu sync.RWMutex + closeOnce sync.Once + pubKeys map[peer.ID]ic.PubKey + listeners map[string]*listener +} + +func (h *hub) addListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + h.listeners[addr] = l +} + +func (h *hub) removeListener(addr string, l *listener) { + h.mu.Lock() + defer h.mu.Unlock() + + delete(h.listeners, addr) +} + +func (h *hub) getListener(addr string) (*listener, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + l, ok := h.listeners[addr] + return l, ok +} + +func (h *hub) addPubKey(p peer.ID, pk ic.PubKey) { + h.mu.Lock() + defer h.mu.Unlock() + + h.pubKeys[p] = pk +} + +func (h *hub) getPubKey(p peer.ID) (ic.PubKey, bool) { + h.mu.RLock() + defer h.mu.RUnlock() + + pk, ok := h.pubKeys[p] + return pk, ok +} + +func (h *hub) close() { + h.closeOnce.Do(func() { + h.mu.Lock() + defer h.mu.Unlock() + + for _, l := range h.listeners { + l.Close() + } + }) +} + +var memhub = &hub{ + listeners: make(map[string]*listener), + pubKeys: make(map[peer.ID]ic.PubKey), +} type transport struct { - pkey ic.PrivKey - pid peer.ID - psk pnet.PSK - rcmgr network.ResourceManager + psk pnet.PSK + rcmgr network.ResourceManager + localPeerID peer.ID + localPrivKey ic.PrivKey + localPubKey ic.PubKey mu sync.RWMutex connID atomic.Int32 - listeners map[string]*listener connections map[int32]*conn } -func NewTransport(key ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { +func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { if rcmgr == nil { rcmgr = &network.NullResourceManager{} } - id, err := peer.IDFromPrivateKey(key) + id, err := peer.IDFromPrivateKey(privKey) if err != nil { return nil, err } + memhub.addPubKey(id, privKey.GetPublic()) return &transport{ - rcmgr: rcmgr, - pid: id, - pkey: key, - psk: psk, - listeners: make(map[string]*listener), - connections: make(map[int32]*conn), + psk: psk, + rcmgr: rcmgr, + localPeerID: id, + localPrivKey: privKey, + localPubKey: privKey.GetPublic(), + connections: make(map[int32]*conn), }, nil } @@ -64,28 +123,37 @@ func (t *transport) Dial(ctx context.Context, raddr ma.Multiaddr, p peer.ID) (tp return c, nil } -func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { - if err := scope.SetPeer(p); err != nil { +func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid peer.ID, scope network.ConnManagementScope) (tpt.CapableConn, error) { + if err := scope.SetPeer(rpid); err != nil { return nil, err } // TODO: Check if there is an existing listener for this address t.mu.RLock() defer t.mu.RUnlock() - l := t.listeners[raddr.String()] + l, ok := memhub.getListener(raddr.String()) + if !ok { + return nil, errors.New("failed to get listener") + } + + remotePubKey, ok := memhub.getPubKey(rpid) + if !ok { + return nil, errors.New("failed to get remote public key") + } ra, wb := io.Pipe() rb, wa := io.Pipe() - in, out := newStream(0, ra, wb), newStream(0, rb, wa) - inId, outId := t.connID.Add(1), t.connID.Add(1) + inConnId, outConnId := t.connID.Add(1), t.connID.Add(1) + inStream, outStream := newStream(0, ra, wb), newStream(0, rb, wa) - l.connCh <- newConnection(inId, in) + l.connCh <- newConnection(inConnId, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) - return newConnection(outId, out), nil + return newConnection(outConnId, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr), nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { - return true + _, exists := memhub.getListener(addr.String()) + return exists } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { @@ -95,7 +163,7 @@ func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { t.mu.Lock() defer t.mu.Unlock() - t.listeners[laddr.String()] = l + memhub.addListener(laddr.String(), l) return l, nil } @@ -106,7 +174,7 @@ func (t *transport) Proxy() bool { // Protocols returns the set of protocols handled by this transport. func (t *transport) Protocols() []int { - return []int{777} + return []int{ma.P_MEMORY} } func (t *transport) String() string { @@ -115,10 +183,7 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them - for _, l := range t.listeners { - l.Close() - } - + memhub.close() return nil } From 079bd3e5657177fa6e5d6fc5c40e476eaa950953 Mon Sep 17 00:00:00 2001 From: Srdjan S Date: Sat, 16 Nov 2024 15:56:56 +0100 Subject: [PATCH 6/6] Use plain channels to send data between streams --- p2p/transport/memory/conn.go | 45 ++++---- p2p/transport/memory/listener.go | 5 +- p2p/transport/memory/stream.go | 144 +++++++++++++++---------- p2p/transport/memory/stream_test.go | 12 +-- p2p/transport/memory/transport.go | 35 +++--- p2p/transport/memory/transport_test.go | 70 ++++++++++++ 6 files changed, 201 insertions(+), 110 deletions(-) create mode 100644 p2p/transport/memory/transport_test.go diff --git a/p2p/transport/memory/conn.go b/p2p/transport/memory/conn.go index d864e93316..515fb43625 100644 --- a/p2p/transport/memory/conn.go +++ b/p2p/transport/memory/conn.go @@ -2,7 +2,6 @@ package memory import ( "context" - "io" "sync" "sync/atomic" @@ -14,7 +13,7 @@ import ( ) type conn struct { - id int32 + id int64 transport *transport scope network.ConnManagementScope @@ -26,21 +25,19 @@ type conn struct { remotePubKey ic.PubKey remoteMultiaddr ma.Multiaddr - isClosed atomic.Bool - closeOnce sync.Once - mu sync.Mutex - streamC chan *stream + closed atomic.Bool + closeOnce sync.Once - nextStreamID atomic.Int32 - streams map[int32]network.MuxedStream + streamC chan *stream + streams map[int64]network.MuxedStream } var _ tpt.CapableConn = &conn{} func newConnection( - id int32, + t *transport, s *stream, localPeer peer.ID, localMultiaddr ma.Multiaddr, @@ -49,40 +46,36 @@ func newConnection( remoteMultiaddr ma.Multiaddr, ) *conn { c := &conn{ - id: id, + id: connCounter.Add(1), + transport: t, localPeer: localPeer, localMultiaddr: localMultiaddr, remotePubKey: remotePubKey, remotePeerID: remotePeer, remoteMultiaddr: remoteMultiaddr, streamC: make(chan *stream, 1), - streams: make(map[int32]network.MuxedStream), + streams: make(map[int64]network.MuxedStream), } - streamID := c.nextStreamID.Add(1) - c.addStream(streamID, s) - + c.addStream(s.id, s) return c } func (c *conn) Close() error { - c.closeOnce.Do(func() { - c.isClosed.Store(true) - c.transport.removeConn(c) - }) + c.closed.Store(true) + for _, s := range c.streams { + s.Close() + } return nil } func (c *conn) IsClosed() bool { - return c.isClosed.Load() + return c.closed.Load() } func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { - ra, wb := io.Pipe() - rb, wa := io.Pipe() - inConnId, outConnId := c.nextStreamID.Add(1), c.nextStreamID.Add(1) - inStream, outStream := newStream(inConnId, ra, wb), newStream(outConnId, rb, wa) + inStream, outStream := newStreamPair() c.streamC <- inStream return outStream, nil @@ -90,7 +83,7 @@ func (c *conn) OpenStream(ctx context.Context) (network.MuxedStream, error) { func (c *conn) AcceptStream() (network.MuxedStream, error) { in := <-c.streamC - id := c.nextStreamID.Add(1) + id := streamCounter.Add(1) c.addStream(id, in) return in, nil } @@ -122,14 +115,14 @@ func (c *conn) ConnState() network.ConnectionState { return network.ConnectionState{Transport: "memory"} } -func (c *conn) addStream(id int32, stream network.MuxedStream) { +func (c *conn) addStream(id int64, stream network.MuxedStream) { c.mu.Lock() defer c.mu.Unlock() c.streams[id] = stream } -func (c *conn) removeStream(id int32) { +func (c *conn) removeStream(id int64) { c.mu.Lock() defer c.mu.Unlock() diff --git a/p2p/transport/memory/listener.go b/p2p/transport/memory/listener.go index 39e8acfb29..54417e2a8b 100644 --- a/p2p/transport/memory/listener.go +++ b/p2p/transport/memory/listener.go @@ -21,7 +21,7 @@ type listener struct { mu sync.Mutex connCh chan *conn - connections map[int32]*conn + connections map[int64]*conn } func (l *listener) Multiaddr() ma.Multiaddr { @@ -36,7 +36,7 @@ func newListener(t *transport, laddr ma.Multiaddr) *listener { cancel: cancel, laddr: laddr, connCh: make(chan *conn, listenerQueueSize), - connections: make(map[int32]*conn), + connections: make(map[int64]*conn), } } @@ -53,6 +53,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) { l.mu.Lock() defer l.mu.Unlock() + c.transport = l.t l.connections[c.id] = c return c, nil } diff --git a/p2p/transport/memory/stream.go b/p2p/transport/memory/stream.go index 101ae516da..66d8879f88 100644 --- a/p2p/transport/memory/stream.go +++ b/p2p/transport/memory/stream.go @@ -1,106 +1,132 @@ package memory import ( + "errors" "io" + "net" "sync/atomic" "time" "github.com/libp2p/go-libp2p/core/network" ) +// stream implements network.Stream type stream struct { - id int32 + id int64 - r *io.PipeReader - w *io.PipeWriter - writeC chan []byte + write chan byte + read chan byte - readCloseC chan struct{} - writeCloseC chan struct{} + reset chan struct{} + closeRead chan struct{} + closeWrite chan struct{} + closed atomic.Bool +} + +var ErrClosed = errors.New("stream closed") - closed atomic.Bool +func newStreamPair() (*stream, *stream) { + ra, rb := make(chan byte, 4096), make(chan byte, 4096) + wa, wb := rb, ra + + in := newStream(rb, wb, network.DirInbound) + out := newStream(ra, wa, network.DirOutbound) + return in, out } -func newStream(id int32, r *io.PipeReader, w *io.PipeWriter) *stream { +func newStream(r, w chan byte, _ network.Direction) *stream { s := &stream{ - id: id, - r: r, - w: w, - writeC: make(chan []byte, 1), - readCloseC: make(chan struct{}, 1), - writeCloseC: make(chan struct{}, 1), + id: streamCounter.Add(1), + read: r, + write: w, + reset: make(chan struct{}, 1), + closeRead: make(chan struct{}, 1), + closeWrite: make(chan struct{}, 1), } - go func() { - for { - select { - case b := <-s.writeC: - if _, err := w.Write(b); err != nil { - return - } - case <-s.writeCloseC: - return - } - } - }() - return s } -func (s *stream) Read(b []byte) (int, error) { - return s.r.Read(b) -} - -func (s *stream) Write(b []byte) (int, error) { +// How to handle errors with writes? +func (s *stream) Write(p []byte) (n int, err error) { if s.closed.Load() { - return 0, network.ErrReset + return 0, ErrClosed } - select { - case <-s.writeCloseC: - return 0, network.ErrReset - case s.writeC <- b: - return len(b), nil + for i := 0; i < len(p); i++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeWrite: + err = ErrClosed + return + case s.write <- p[i]: + n = i + default: + err = io.ErrClosedPipe + } } + + return n + 1, err } -func (s *stream) Reset() error { - if err := s.CloseWrite(); err != nil { - return err +func (s *stream) Read(p []byte) (n int, err error) { + if s.closed.Load() { + return 0, ErrClosed } - if err := s.CloseRead(); err != nil { - return err + + for n = 0; n < len(p); n++ { + select { + case <-s.reset: + err = network.ErrReset + return + case <-s.closeRead: + err = ErrClosed + return + case b, ok := <-s.read: + if !ok { + err = io.EOF + return + } + p[n] = b + default: + err = io.EOF + return + } } - return nil + + return } -func (s *stream) Close() error { - s.CloseRead() - s.CloseWrite() +func (s *stream) CloseWrite() error { + s.closeWrite <- struct{}{} return nil } func (s *stream) CloseRead() error { - return s.r.CloseWithError(network.ErrReset) + s.closeRead <- struct{}{} + return nil } -func (s *stream) CloseWrite() error { - select { - case s.writeCloseC <- struct{}{}: - default: - } - +func (s *stream) Close() error { s.closed.Store(true) return nil } -func (s *stream) SetDeadline(_ time.Time) error { +func (s *stream) Reset() error { + s.reset <- struct{}{} return nil } -func (s *stream) SetReadDeadline(_ time.Time) error { - return nil +func (s *stream) SetDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } -func (s *stream) SetWriteDeadline(_ time.Time) error { - return nil + +func (s *stream) SetReadDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} +} + +func (s *stream) SetWriteDeadline(t time.Time) error { + return &net.OpError{Op: "set", Net: "pipe", Source: nil, Addr: nil, Err: errors.New("deadline not supported")} } diff --git a/p2p/transport/memory/stream_test.go b/p2p/transport/memory/stream_test.go index 33c3cbdc64..cd5149c685 100644 --- a/p2p/transport/memory/stream_test.go +++ b/p2p/transport/memory/stream_test.go @@ -8,22 +8,17 @@ import ( ) func TestStreamSimpleReadWriteClose(t *testing.T) { - //client, server := getDetachedDataChannels(t) - ra, wb := io.Pipe() - rb, wa := io.Pipe() - - clientStr := newStream(0, ra, wa) - serverStr := newStream(1, rb, wb) + clientStr, serverStr := newStreamPair() // send a foobar from the client n, err := clientStr.Write([]byte("foobar")) require.NoError(t, err) require.Equal(t, 6, n) require.NoError(t, clientStr.CloseWrite()) + // writing after closing should error _, err = clientStr.Write([]byte("foobar")) require.Error(t, err) - //require.False(t, clientDone.Load()) // now read all the data on the server side b, err := io.ReadAll(serverStr) @@ -33,7 +28,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { n, err = serverStr.Read(make([]byte, 10)) require.Zero(t, n) require.ErrorIs(t, err, io.EOF) - //require.False(t, serverDone.Load()) // send something back _, err = serverStr.Write([]byte("lorem ipsum")) @@ -49,8 +43,6 @@ func TestStreamSimpleReadWriteClose(t *testing.T) { // stream is only cleaned up on calling Close or Reset clientStr.Close() serverStr.Close() - //require.Eventually(t, func() bool { return clientDone.Load() }, 5*time.Second, 100*time.Millisecond) // Need to call Close for cleanup. Otherwise the FIN_ACK is never read require.NoError(t, serverStr.Close()) - //require.Eventually(t, func() bool { return serverDone.Load() }, 5*time.Second, 100*time.Millisecond) } diff --git a/p2p/transport/memory/transport.go b/p2p/transport/memory/transport.go index 5016e3a7dd..a13c737437 100644 --- a/p2p/transport/memory/transport.go +++ b/p2p/transport/memory/transport.go @@ -3,7 +3,6 @@ package memory import ( "context" "errors" - "io" "sync" "sync/atomic" @@ -13,6 +12,14 @@ import ( "github.com/libp2p/go-libp2p/core/pnet" tpt "github.com/libp2p/go-libp2p/core/transport" ma "github.com/multiformats/go-multiaddr" + mafmt "github.com/multiformats/go-multiaddr-fmt" +) + +var ( + connCounter atomic.Int64 + streamCounter atomic.Int64 + listenerCounter atomic.Int64 + dialMatcher = mafmt.Base(ma.P_MEMORY) ) type hub struct { @@ -84,8 +91,7 @@ type transport struct { mu sync.RWMutex - connID atomic.Int32 - connections map[int32]*conn + connections map[int64]*conn } func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManager) (tpt.Transport, error) { @@ -105,7 +111,7 @@ func NewTransport(privKey ic.PrivKey, psk pnet.PSK, rcmgr network.ResourceManage localPeerID: id, localPrivKey: privKey, localPubKey: privKey.GetPublic(), - connections: make(map[int32]*conn), + connections: make(map[int64]*conn), }, nil } @@ -141,19 +147,16 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, rpid return nil, errors.New("failed to get remote public key") } - ra, wb := io.Pipe() - rb, wa := io.Pipe() - inConnId, outConnId := t.connID.Add(1), t.connID.Add(1) - inStream, outStream := newStream(0, ra, wb), newStream(0, rb, wa) - - l.connCh <- newConnection(inConnId, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) + inStream, outStream := newStreamPair() + inConn := newConnection(t, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr) + outConn := newConnection(nil, inStream, rpid, raddr, t.localPubKey, t.localPeerID, nil) + l.connCh <- outConn - return newConnection(outConnId, outStream, t.localPeerID, nil, remotePubKey, rpid, raddr), nil + return inConn, nil } func (t *transport) CanDial(addr ma.Multiaddr) bool { - _, exists := memhub.getListener(addr.String()) - return exists + return dialMatcher.Matches(addr) } func (t *transport) Listen(laddr ma.Multiaddr) (tpt.Listener, error) { @@ -184,6 +187,12 @@ func (t *transport) String() string { func (t *transport) Close() error { // TODO: Go trough all listeners and close them memhub.close() + + for _, c := range t.connections { + c.Close() + delete(t.connections, c.id) + } + return nil } diff --git a/p2p/transport/memory/transport_test.go b/p2p/transport/memory/transport_test.go new file mode 100644 index 0000000000..f83f0d1280 --- /dev/null +++ b/p2p/transport/memory/transport_test.go @@ -0,0 +1,70 @@ +package memory + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "io" + "testing" + + ic "github.com/libp2p/go-libp2p/core/crypto" + tpt "github.com/libp2p/go-libp2p/core/transport" + + ma "github.com/multiformats/go-multiaddr" + "github.com/stretchr/testify/require" +) + +func getTransport(t *testing.T) tpt.Transport { + t.Helper() + rsaKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + key, err := ic.UnmarshalRsaPrivateKey(x509.MarshalPKCS1PrivateKey(rsaKey)) + require.NoError(t, err) + tr, err := NewTransport(key, nil, nil) + require.NoError(t, err) + return tr +} + +func TestMemoryProtocol(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + protocols := tr.Protocols() + if len(protocols) > 1 { + t.Fatalf("expected at most one protocol, got %v", protocols) + } + + if protocols[0] != ma.P_MEMORY { + t.Fatalf("expected the supported protocol to be memory, got %d", protocols[0]) + } +} + +func TestCanDial(t *testing.T) { + tr := getTransport(t) + defer tr.(io.Closer).Close() + + invalid := []string{ + "/ip4/127.0.0.1/udp/1234", + "/ip4/5.5.5.5/tcp/1234", + "/dns/google.com/udp/443/quic-v1", + "/ip4/127.0.0.1/udp/1234/quic", + } + valid := []string{ + "/memory/1234", + "/memory/1337123", + } + for _, s := range invalid { + invalidAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if tr.CanDial(invalidAddr) { + t.Errorf("didn't expect to be able to dial a non-memory address (%s)", invalidAddr) + } + } + for _, s := range valid { + validAddr, err := ma.NewMultiaddr(s) + require.NoError(t, err) + if !tr.CanDial(validAddr) { + t.Errorf("expected to be able to dial memory address (%s)", validAddr) + } + } +}