Skip to content

Commit 4c6a03d

Browse files
committed
Ensure protocol naming matches latest Starknet p2p spec
Use dht.ProtocolExtension to set chainID in protocol ID format for DHT. Add test to verify protocol ID format for different networks: - /starknet/SN_SEPOLIA/kad/1.0.0 for Sepolia - /starknet/SN_MAIN/kad/1.0.0 for Mainnet The change ensures that DHT protocol follow latest Starknet specification.
1 parent 4ff174d commit 4c6a03d

File tree

6 files changed

+132
-24
lines changed

6 files changed

+132
-24
lines changed

Diff for: .github/workflows/juno-test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,4 @@ jobs:
6565
with:
6666
token: ${{ secrets.CODECOV_TOKEN }}
6767
fail_ci_if_error: true
68-
files: coverage.out
68+
files: coverage.out

Diff for: p2p/p2p.go

+8-8
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
138138
}
139139
}
140140

141-
p2pdht, err := makeDHT(p2phost, peersAddrInfoS)
141+
p2pdht, err := MakeDHT(p2phost, peersAddrInfoS, snNetwork.L2ChainID)
142142
if err != nil {
143143
return nil, err
144144
}
@@ -159,9 +159,9 @@ func NewWithHost(p2phost host.Host, peers string, feederNode bool, bc *blockchai
159159
return s, nil
160160
}
161161

162-
func makeDHT(p2phost host.Host, addrInfos []peer.AddrInfo) (*dht.IpfsDHT, error) {
162+
func MakeDHT(p2phost host.Host, addrInfos []peer.AddrInfo, chainID string) (*dht.IpfsDHT, error) {
163163
return dht.New(context.Background(), p2phost,
164-
dht.ProtocolPrefix(starknet.Prefix),
164+
dht.ProtocolPrefix(starknet.ChainPID(chainID)),
165165
dht.BootstrapPeers(addrInfos...),
166166
dht.RoutingTableRefreshPeriod(routingTableRefreshPeriod),
167167
dht.Mode(dht.ModeServer),
@@ -249,11 +249,11 @@ func (s *Service) Run(ctx context.Context) error {
249249
}
250250

251251
func (s *Service) setProtocolHandlers() {
252-
s.SetProtocolHandler(starknet.HeadersPID(), s.handler.HeadersHandler)
253-
s.SetProtocolHandler(starknet.EventsPID(), s.handler.EventsHandler)
254-
s.SetProtocolHandler(starknet.TransactionsPID(), s.handler.TransactionsHandler)
255-
s.SetProtocolHandler(starknet.ClassesPID(), s.handler.ClassesHandler)
256-
s.SetProtocolHandler(starknet.StateDiffPID(), s.handler.StateDiffHandler)
252+
s.SetProtocolHandler(starknet.HeadersPID(s.network.L2ChainID), s.handler.HeadersHandler)
253+
s.SetProtocolHandler(starknet.EventsPID(s.network.L2ChainID), s.handler.EventsHandler)
254+
s.SetProtocolHandler(starknet.TransactionsPID(s.network.L2ChainID), s.handler.TransactionsHandler)
255+
s.SetProtocolHandler(starknet.ClassesPID(s.network.L2ChainID), s.handler.ClassesHandler)
256+
s.SetProtocolHandler(starknet.StateDiffPID(s.network.L2ChainID), s.handler.StateDiffHandler)
257257
}
258258

259259
func (s *Service) callAndLogErr(f func() error, msg string) {

Diff for: p2p/p2p_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import (
88
"github.com/NethermindEth/juno/p2p"
99
"github.com/NethermindEth/juno/utils"
1010
"github.com/libp2p/go-libp2p/core/peer"
11+
"github.com/libp2p/go-libp2p/core/protocol"
12+
mocknet "github.com/libp2p/go-libp2p/p2p/net/mock"
1113
"github.com/multiformats/go-multiaddr"
14+
"github.com/stretchr/testify/assert"
1215
"github.com/stretchr/testify/require"
1316
)
1417

@@ -64,3 +67,36 @@ func TestLoadAndPersistPeers(t *testing.T) {
6467
)
6568
require.NoError(t, err)
6669
}
70+
71+
func TestMakeDHTProtocolName(t *testing.T) {
72+
net, err := mocknet.FullMeshLinked(1)
73+
require.NoError(t, err)
74+
testHost := net.Hosts()[0]
75+
76+
testCases := []struct {
77+
name string
78+
network *utils.Network
79+
expected string
80+
}{
81+
{
82+
name: "sepolia network",
83+
network: &utils.Sepolia,
84+
expected: "/starknet/SN_SEPOLIA/sync/kad/1.0.0",
85+
},
86+
{
87+
name: "mainnet network",
88+
network: &utils.Mainnet,
89+
expected: "/starknet/SN_MAIN/sync/kad/1.0.0",
90+
},
91+
}
92+
93+
for _, tc := range testCases {
94+
t.Run(tc.name, func(t *testing.T) {
95+
dht, err := p2p.MakeDHT(testHost, nil, tc.network.L2ChainID)
96+
require.NoError(t, err)
97+
98+
protocols := dht.Host().Mux().Protocols()
99+
assert.Contains(t, protocols, protocol.ID(tc.expected), "protocol list: %v", protocols)
100+
})
101+
}
102+
}

Diff for: p2p/starknet/client.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -104,22 +104,24 @@ func (c *Client) RequestBlockHeaders(
104104
ctx context.Context, req *spec.BlockHeadersRequest,
105105
) (iter.Seq[*spec.BlockHeadersResponse], error) {
106106
return requestAndReceiveStream[*spec.BlockHeadersRequest, *spec.BlockHeadersResponse](
107-
ctx, c.newStream, HeadersPID(), req, c.log)
107+
ctx, c.newStream, HeadersPID(c.network.L2ChainID), req, c.log)
108108
}
109109

110110
func (c *Client) RequestEvents(ctx context.Context, req *spec.EventsRequest) (iter.Seq[*spec.EventsResponse], error) {
111-
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(), req, c.log)
111+
return requestAndReceiveStream[*spec.EventsRequest, *spec.EventsResponse](ctx, c.newStream, EventsPID(c.network.L2ChainID), req, c.log)
112112
}
113113

114114
func (c *Client) RequestClasses(ctx context.Context, req *spec.ClassesRequest) (iter.Seq[*spec.ClassesResponse], error) {
115-
return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(), req, c.log)
115+
return requestAndReceiveStream[*spec.ClassesRequest, *spec.ClassesResponse](ctx, c.newStream, ClassesPID(c.network.L2ChainID), req, c.log)
116116
}
117117

118118
func (c *Client) RequestStateDiffs(ctx context.Context, req *spec.StateDiffsRequest) (iter.Seq[*spec.StateDiffsResponse], error) {
119-
return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](ctx, c.newStream, StateDiffPID(), req, c.log)
119+
return requestAndReceiveStream[*spec.StateDiffsRequest, *spec.StateDiffsResponse](
120+
ctx, c.newStream, StateDiffPID(c.network.L2ChainID), req, c.log,
121+
)
120122
}
121123

122124
func (c *Client) RequestTransactions(ctx context.Context, req *spec.TransactionsRequest) (iter.Seq[*spec.TransactionsResponse], error) {
123125
return requestAndReceiveStream[*spec.TransactionsRequest, *spec.TransactionsResponse](
124-
ctx, c.newStream, TransactionsPID(), req, c.log)
126+
ctx, c.newStream, TransactionsPID(c.network.L2ChainID), req, c.log)
125127
}

Diff for: p2p/starknet/ids.go

+14-10
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,26 @@ import (
66

77
const Prefix = "/starknet"
88

9-
func HeadersPID() protocol.ID {
10-
return Prefix + "/headers/0.1.0-rc.0"
9+
func HeadersPID(chainID string) protocol.ID {
10+
return protocol.ID(Prefix + "/" + chainID + "/sync/headers/0.1.0-rc.0")
1111
}
1212

13-
func EventsPID() protocol.ID {
14-
return Prefix + "/events/0.1.0-rc.0"
13+
func EventsPID(chainID string) protocol.ID {
14+
return protocol.ID(Prefix + "/" + chainID + "/sync/events/0.1.0-rc.0")
1515
}
1616

17-
func TransactionsPID() protocol.ID {
18-
return Prefix + "/transactions/0.1.0-rc.0"
17+
func TransactionsPID(chainID string) protocol.ID {
18+
return protocol.ID(Prefix + "/" + chainID + "/sync/transactions/0.1.0-rc.0")
1919
}
2020

21-
func ClassesPID() protocol.ID {
22-
return Prefix + "/classes/0.1.0-rc.0"
21+
func ClassesPID(chainID string) protocol.ID {
22+
return protocol.ID(Prefix + "/" + chainID + "/sync/classes/0.1.0-rc.0")
2323
}
2424

25-
func StateDiffPID() protocol.ID {
26-
return Prefix + "/state_diffs/0.1.0-rc.0"
25+
func StateDiffPID(chainID string) protocol.ID {
26+
return protocol.ID(Prefix + "/" + chainID + "/sync/state_diffs/0.1.0-rc.0")
27+
}
28+
29+
func ChainPID(chainID string) protocol.ID {
30+
return protocol.ID(Prefix + "/" + chainID + "/sync")
2731
}

Diff for: p2p/starknet/ids_test.go

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package starknet
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestProtocolIDs(t *testing.T) {
10+
testCases := []struct {
11+
name string
12+
chainID string
13+
pidFunc func(string) string
14+
expected string
15+
}{
16+
{
17+
name: "HeadersPID with SN_MAIN",
18+
chainID: "SN_MAIN",
19+
pidFunc: func(c string) string { return string(HeadersPID(c)) },
20+
expected: "/starknet/SN_MAIN/sync/headers/0.1.0-rc.0",
21+
},
22+
{
23+
name: "EventsPID with SN_MAIN",
24+
chainID: "SN_MAIN",
25+
pidFunc: func(c string) string { return string(EventsPID(c)) },
26+
expected: "/starknet/SN_MAIN/sync/events/0.1.0-rc.0",
27+
},
28+
{
29+
name: "TransactionsPID with SN_MAIN",
30+
chainID: "SN_MAIN",
31+
pidFunc: func(c string) string { return string(TransactionsPID(c)) },
32+
expected: "/starknet/SN_MAIN/sync/transactions/0.1.0-rc.0",
33+
},
34+
{
35+
name: "ClassesPID with SN_MAIN",
36+
chainID: "SN_MAIN",
37+
pidFunc: func(c string) string { return string(ClassesPID(c)) },
38+
expected: "/starknet/SN_MAIN/sync/classes/0.1.0-rc.0",
39+
},
40+
{
41+
name: "StateDiffPID with SN_MAIN",
42+
chainID: "SN_MAIN",
43+
pidFunc: func(c string) string { return string(StateDiffPID(c)) },
44+
expected: "/starknet/SN_MAIN/sync/state_diffs/0.1.0-rc.0",
45+
},
46+
{
47+
name: "ChainPID with SN_MAIN",
48+
chainID: "SN_MAIN",
49+
pidFunc: func(c string) string { return string(ChainPID(c)) },
50+
expected: "/starknet/SN_MAIN/sync",
51+
},
52+
{
53+
name: "HeadersPID with SN_SEPOLIA",
54+
chainID: "SN_SEPOLIA",
55+
pidFunc: func(c string) string { return string(HeadersPID(c)) },
56+
expected: "/starknet/SN_SEPOLIA/sync/headers/0.1.0-rc.0",
57+
},
58+
}
59+
60+
for _, tc := range testCases {
61+
t.Run(tc.name, func(t *testing.T) {
62+
result := tc.pidFunc(tc.chainID)
63+
assert.Equal(t, tc.expected, result)
64+
})
65+
}
66+
}

0 commit comments

Comments
 (0)