Skip to content

Commit d713fd4

Browse files
committed
Allocate persistent connection IDs immediately when generated
Guards against future bugs where multiple (e.g. concurrent) calls to `new_cid` might otherwise lead to an undetected collision.
1 parent 114a991 commit d713fd4

File tree

1 file changed

+31
-25
lines changed

1 file changed

+31
-25
lines changed

quinn-proto/src/endpoint.rs

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::{
2-
collections::HashMap,
2+
collections::{hash_map, HashMap},
33
convert::TryFrom,
44
fmt, iter,
55
net::{IpAddr, SocketAddr},
@@ -324,7 +324,8 @@ impl Endpoint {
324324
let remote_id = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
325325
trace!(initial_dcid = %remote_id);
326326

327-
let loc_cid = self.new_cid();
327+
let ch = ConnectionHandle(self.connections.vacant_key());
328+
let loc_cid = self.new_cid(ch);
328329
let params = TransportParameters::new(
329330
&config.transport,
330331
&self.config,
@@ -336,7 +337,8 @@ impl Endpoint {
336337
.crypto
337338
.start_session(config.version, server_name, &params)?;
338339

339-
let (ch, conn) = self.add_connection(
340+
let conn = self.add_connection(
341+
ch,
340342
config.version,
341343
remote_id,
342344
loc_cid,
@@ -361,8 +363,7 @@ impl Endpoint {
361363
) -> ConnectionEvent {
362364
let mut ids = vec![];
363365
for _ in 0..num {
364-
let id = self.new_cid();
365-
self.index.insert_cid(id, ch);
366+
let id = self.new_cid(ch);
366367
let meta = &mut self.connections[ch];
367368
meta.cids_issued += 1;
368369
let sequence = meta.cids_issued;
@@ -376,10 +377,12 @@ impl Endpoint {
376377
ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now))
377378
}
378379

379-
fn new_cid(&mut self) -> ConnectionId {
380+
/// Generate a connection ID for `ch`
381+
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
380382
loop {
381383
let cid = self.local_cid_generator.generate_cid();
382-
if !self.index.connection_ids.contains_key(&cid) {
384+
if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
385+
e.insert(ch);
383386
break cid;
384387
}
385388
assert!(self.local_cid_generator.cid_len() > 0);
@@ -423,8 +426,7 @@ impl Endpoint {
423426
return None;
424427
}
425428

426-
let loc_cid = self.new_cid();
427-
let server_config = self.server_config.as_ref().unwrap();
429+
let server_config = self.server_config.as_ref().unwrap().clone();
428430

429431
if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full()
430432
{
@@ -434,7 +436,6 @@ impl Endpoint {
434436
addresses,
435437
crypto,
436438
&src_cid,
437-
&loc_cid,
438439
TransportError::CONNECTION_REFUSED(""),
439440
)));
440441
}
@@ -451,7 +452,6 @@ impl Endpoint {
451452
addresses,
452453
crypto,
453454
&src_cid,
454-
&loc_cid,
455455
TransportError::PROTOCOL_VIOLATION("invalid destination CID length"),
456456
)));
457457
}
@@ -461,6 +461,12 @@ impl Endpoint {
461461
// First Initial
462462
let mut random_bytes = vec![0u8; RetryToken::RANDOM_BYTES_LEN];
463463
self.rng.fill_bytes(&mut random_bytes);
464+
// The peer will use this as the DCID of its following Initials. Initial DCIDs are
465+
// looked up separately from Handshake/Data DCIDs, so there is no risk of collision
466+
// with established connections. In the unlikely event that a collision occurs
467+
// between two connections in the initial phase, both will fail fast and may be
468+
// retried by the application layer.
469+
let loc_cid = self.local_cid_generator.generate_cid();
464470

465471
let token = RetryToken {
466472
orig_dst_cid: dst_cid,
@@ -508,7 +514,6 @@ impl Endpoint {
508514
addresses,
509515
crypto,
510516
&src_cid,
511-
&loc_cid,
512517
TransportError::INVALID_TOKEN(""),
513518
)));
514519
}
@@ -517,7 +522,8 @@ impl Endpoint {
517522
(None, dst_cid)
518523
};
519524

520-
let server_config = server_config.clone();
525+
let ch = ConnectionHandle(self.connections.vacant_key());
526+
let loc_cid = self.new_cid(ch);
521527
let mut params = TransportParameters::new(
522528
&server_config.transport,
523529
&self.config,
@@ -531,7 +537,8 @@ impl Endpoint {
531537

532538
let tls = server_config.crypto.clone().start_session(version, &params);
533539
let transport_config = server_config.transport.clone();
534-
let (ch, mut conn) = self.add_connection(
540+
let mut conn = self.add_connection(
541+
ch,
535542
version,
536543
dst_cid,
537544
loc_cid,
@@ -555,7 +562,7 @@ impl Endpoint {
555562
self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
556563
match e {
557564
ConnectionError::TransportError(e) => Some(DatagramEvent::Response(
558-
self.initial_close(version, addresses, crypto, &src_cid, &loc_cid, e),
565+
self.initial_close(version, addresses, crypto, &src_cid, e),
559566
)),
560567
_ => None,
561568
}
@@ -565,6 +572,7 @@ impl Endpoint {
565572

566573
fn add_connection(
567574
&mut self,
575+
ch: ConnectionHandle,
568576
version: u32,
569577
init_cid: ConnectionId,
570578
loc_cid: ConnectionId,
@@ -574,7 +582,7 @@ impl Endpoint {
574582
tls: Box<dyn crypto::Session>,
575583
server_config: Option<Arc<ServerConfig>>,
576584
transport_config: Arc<TransportConfig>,
577-
) -> (ConnectionHandle, Connection) {
585+
) -> Connection {
578586
let conn = Connection::new(
579587
self.config.clone(),
580588
server_config,
@@ -598,11 +606,11 @@ impl Endpoint {
598606
addresses,
599607
reset_token: None,
600608
});
609+
debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");
601610

602-
let ch = ConnectionHandle(id);
603611
self.index.insert_conn(addresses, loc_cid, ch);
604612

605-
(ch, conn)
613+
conn
606614
}
607615

608616
fn initial_close(
@@ -611,13 +619,16 @@ impl Endpoint {
611619
addresses: FourTuple,
612620
crypto: &Keys,
613621
remote_id: &ConnectionId,
614-
local_id: &ConnectionId,
615622
reason: TransportError,
616623
) -> Transmit {
624+
// We don't need to worry about CID collisions in initial closes because the peer
625+
// shouldn't respond, and if it does, and the CID collides, we'll just drop the
626+
// unexpected response.
627+
let local_id = self.local_cid_generator.generate_cid();
617628
let number = PacketNumber::U8(0);
618629
let header = Header::Initial {
619630
dst_cid: *remote_id,
620-
src_cid: *local_id,
631+
src_cid: local_id,
621632
number,
622633
token: Bytes::new(),
623634
version,
@@ -735,11 +746,6 @@ impl ConnectionIndex {
735746
}
736747
}
737748

738-
/// Add a new CID to an existing connection
739-
fn insert_cid(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
740-
self.connection_ids.insert(dst_cid, connection);
741-
}
742-
743749
/// Discard a connection ID
744750
fn retire(&mut self, dst_cid: &ConnectionId) {
745751
self.connection_ids.remove(dst_cid);

0 commit comments

Comments
 (0)