Skip to content

Commit 8dec96e

Browse files
committed
Ensure p2p protocol matches new Starknet spec
1 parent 0c6508c commit 8dec96e

File tree

5 files changed

+131
-23
lines changed

5 files changed

+131
-23
lines changed

p2p/p2p.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
139139
}
140140
}
141141

142-
p2pdht, err := makeDHT(p2phost, peersAddrInfoS)
142+
p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork)
143143
if err != nil {
144144
return nil, err
145145
}
@@ -160,9 +160,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
160160
return s, nil
161161
}
162162

163-
func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) {
163+
func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, snNetwork *utils.Network) (*dht.IpfsDHT, error) {
164164
return dht.New(context.Background(), p2phost,
165-
dht.ProtocolPrefix(p2pSync.Prefix),
165+
dht.ProtocolPrefix(p2pSync.DHTPrefixPID(snNetwork)),
166166
dht.BootstrapPeers(addrInfos...),
167167
dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod),
168168
dht.Mode(dht.ModeServer),
@@ -250,11 +250,11 @@ func (s *Service) Run(ctx context.Context) error {
250250
}
251251

252252
func (s *Service) setProtocolHandlers() {
253-
s.SetProtocolHandler(p2pSync.HeadersPID(), s.handler.HeadersHandler)
254-
s.SetProtocolHandler(p2pSync.EventsPID(), s.handler.EventsHandler)
255-
s.SetProtocolHandler(p2pSync.TransactionsPID(), s.handler.TransactionsHandler)
256-
s.SetProtocolHandler(p2pSync.ClassesPID(), s.handler.ClassesHandler)
257-
s.SetProtocolHandler(p2pSync.StateDiffPID(), s.handler.StateDiffHandler)
253+
s.SetProtocolHandler(p2pSync.HeadersPID(s.network), s.handler.HeadersHandler)
254+
s.SetProtocolHandler(p2pSync.EventsPID(s.network), s.handler.EventsHandler)
255+
s.SetProtocolHandler(p2pSync.TransactionsPID(s.network), s.handler.TransactionsHandler)
256+
s.SetProtocolHandler(p2pSync.ClassesPID(s.network), s.handler.ClassesHandler)
257+
s.SetProtocolHandler(p2pSync.StateDiffPID(s.network), s.handler.StateDiffHandler)
258258
}
259259

260260
func (s *Service) callAndLogErr(f func() error, msg string) {

p2p/p2p_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import (
88
"github.com/NethermindEth/juno/p2p"
99
"github.com/NethermindEth/juno/utils"
1010
"github.com/libp2p/go-libp2p/core/peer"
11+
"github.com/libp2p/go-libp2p/core/protocol"
12+
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
1113
"github.com/multiformats/go-multiaddr"
14+
"github.com/stretchr/testify/assert"
1215
"github.com/stretchr/testify/require"
1316
)
1417

@@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) {
6467
)
6568
require.NoError(t, err)
6669
}
70+
71+
func TestMakeDHTProtocolName(t *testing.T) {
72+
net, err := mocknet.FullMeshLinked(1)
73+
require.NoError(t, err)
74+
testHost := net.Hosts()[0]
75+
76+
testCases := []struct {
77+
name string
78+
network *utils.Network
79+
expected string
80+
}{
81+
{
82+
name: "sepolia network",
83+
network: &utils.Sepolia,
84+
expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0",
85+
},
86+
{
87+
name: "mainnet network",
88+
network: &utils.Mainnet,
89+
expected: "/starknet/SN_MAIN/sync/kad/1.0.0",
90+
},
91+
}
92+
93+
for _, tc := range testCases {
94+
t.Run(tc.name, func(t *testing.T) {
95+
dht, err := p2p.MakeDHT(testHost, nil, tc.network)
96+
require.NoError(t, err)
97+
98+
protocols := dht.Host().Mux().Protocols()
99+
assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols)
100+
})
101+
}
102+
}

p2p/sync/client.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,22 @@ func (c *Client) RequestBlockHeaders(
104104
ctx context.Context, req *gen.BlockHeadersRequest,
105105
) (iter.Seq[*gen.BlockHeadersResponse], error) {
106106
return requestAndReceiveStream[*gen.BlockHeadersRequest, *gen.BlockHeadersResponse](
107-
ctx, c.newStream, HeadersPID(), req, c.log)
107+
ctx, c.newStream, HeadersPID(c.network), req, c.log)
108108
}
109109

110110
func (c *Client) RequestEvents(ctx context.Context, req *gen.EventsRequest) (iter.Seq[*gen.EventsResponse], error) {
111-
return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log)
111+
return requestAndReceiveStream[*gen.EventsRequest, *gen.EventsResponse](ctx, c.newStream, EventsPID(c.network), req, c.log)
112112
}
113113

114114
func (c *Client) RequestClasses(ctx context.Context, req *gen.ClassesRequest) (iter.Seq[*gen.ClassesResponse], error) {
115-
return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log)
115+
return requestAndReceiveStream[*gen.ClassesRequest, *gen.ClassesResponse](ctx, c.newStream, ClassesPID(c.network), req, c.log)
116116
}
117117

118118
func (c *Client) RequestStateDiffs(ctx context.Context, req *gen.StateDiffsRequest) (iter.Seq[*gen.StateDiffsResponse], error) {
119-
return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log)
119+
return requestAndReceiveStream[*gen.StateDiffsRequest, *gen.StateDiffsResponse](ctx, c.newStream, StateDiffPID(c.network), req, c.log)
120120
}
121121

122122
func (c *Client) RequestTransactions(ctx context.Context, req *gen.TransactionsRequest) (iter.Seq[*gen.TransactionsResponse], error) {
123123
return requestAndReceiveStream[*gen.TransactionsRequest, *gen.TransactionsResponse](
124-
ctx, c.newStream, TransactionsPID(), req, c.log)
124+
ctx, c.newStream, TransactionsPID(c.network), req, c.log)
125125
}

p2p/sync/ids.go

+15-10
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,32 @@
11
package sync
22

33
import (
4+
"github.com/NethermindEth/juno/utils"
45
"github.com/libp2p/go-libp2p/core/protocol"
56
)
67

78
const Prefix = "/starknet"
89

9-
func HeadersPID() protocol.ID {
10-
return Prefix + "/headers/0.1.0-rc.0"
10+
func HeadersPID(network *utils.Network) protocol.ID {
11+
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/headers/0.1.0-rc.0")
1112
}
1213

13-
func EventsPID() protocol.ID {
14-
return Prefix + "/events/0.1.0-rc.0"
14+
func EventsPID(network *utils.Network) protocol.ID {
15+
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/events/0.1.0-rc.0")
1516
}
1617

17-
func TransactionsPID() protocol.ID {
18-
return Prefix + "/transactions/0.1.0-rc.0"
18+
func TransactionsPID(network *utils.Network) protocol.ID {
19+
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/transactions/0.1.0-rc.0")
1920
}
2021

21-
func ClassesPID() protocol.ID {
22-
return Prefix + "/classes/0.1.0-rc.0"
22+
func ClassesPID(network *utils.Network) protocol.ID {
23+
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/classes/0.1.0-rc.0")
2324
}
2425

25-
func StateDiffPID() protocol.ID {
26-
return Prefix + "/state_diffs/0.1.0-rc.0"
26+
func StateDiffPID(network *utils.Network) protocol.ID {
27+
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync/state_diffs/0.1.0-rc.0")
28+
}
29+
30+
func DHTPrefixPID(network *utils.Network) protocol.ID {
31+
return protocol.ID(Prefix + "/" + network.L2ChainID + "/sync")
2732
}

p2p/sync/ids_test.go

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package sync
2+
3+
import (
4+
"testing"
5+
6+
"github.com/NethermindEth/juno/utils"
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestProtocolIDs(t *testing.T) {
11+
testCases := []struct {
12+
name string
13+
network *utils.Network
14+
pidFunc func(*utils.Network) string
15+
expected string
16+
}{
17+
{
18+
name: "HeadersPID with SN_MAIN",
19+
network: &utils.Mainnet,
20+
pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) },
21+
expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0",
22+
},
23+
{
24+
name: "EventsPID with SN_MAIN",
25+
network: &utils.Mainnet,
26+
pidFunc: func(n *utils.Network) string { return string(EventsPID(n)) },
27+
expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0",
28+
},
29+
{
30+
name: "TransactionsPID with SN_MAIN",
31+
network: &utils.Mainnet,
32+
pidFunc: func(n *utils.Network) string { return string(TransactionsPID(n)) },
33+
expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0",
34+
},
35+
{
36+
name: "ClassesPID with SN_MAIN",
37+
network: &utils.Mainnet,
38+
pidFunc: func(n *utils.Network) string { return string(ClassesPID(n)) },
39+
expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0",
40+
},
41+
{
42+
name: "StateDiffPID with SN_MAIN",
43+
network: &utils.Mainnet,
44+
pidFunc: func(n *utils.Network) string { return string(StateDiffPID(n)) },
45+
expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0",
46+
},
47+
{
48+
name: "DHTPrefixPID with SN_MAIN",
49+
network: &utils.Mainnet,
50+
pidFunc: func(n *utils.Network) string { return string(DHTPrefixPID(n)) },
51+
expected: "/starknet/SN_MAIN/sync",
52+
},
53+
{
54+
name: "HeadersPID with SN_SEPOLIA",
55+
network: &utils.Sepolia,
56+
pidFunc: func(n *utils.Network) string { return string(HeadersPID(n)) },
57+
expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0",
58+
},
59+
}
60+
61+
for _, tc := range testCases {
62+
t.Run(tc.name, func(t *testing.T) {
63+
result := tc.pidFunc(tc.network)
64+
assert.Equal(t, tc.expected, result)
65+
})
66+
}
67+
}

0 commit comments

Comments
 (0)