Skip to content

Commit 26b2c37

Browse files
committed
Allocate persistent connection IDs immediately when generated
Guards against future bugs where multiple (e.g. concurrent) calls to `new_cid` might otherwise lead to an undetected collision.
1 parent c4367ae commit 26b2c37

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

quinn-proto/src/endpoint.rs

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use std::{
2-
collections::HashMap,
2+
collections::{hash_map, HashMap},
33
convert::TryFrom,
44
fmt, iter,
55
net::{IpAddr, SocketAddr},
@@ -324,7 +324,8 @@ impl Endpoint {
324324
let remote_id = RandomConnectionIdGenerator::new(MAX_CID_SIZE).generate_cid();
325325
trace!(initial_dcid = %remote_id);
326326

327-
let loc_cid = self.new_cid();
327+
let ch = ConnectionHandle(self.connections.vacant_key());
328+
let loc_cid = self.new_cid(ch);
328329
let params = TransportParameters::new(
329330
&config.transport,
330331
&self.config,
@@ -336,7 +337,7 @@ impl Endpoint {
336337
.crypto
337338
.start_session(config.version, server_name, &params)?;
338339

339-
let (ch, conn) = self.add_connection(
340+
let conn = self.add_connection(
340341
config.version,
341342
remote_id,
342343
loc_cid,
@@ -349,6 +350,7 @@ impl Endpoint {
349350
tls,
350351
None,
351352
config.transport,
353+
ch,
352354
);
353355
Ok((ch, conn))
354356
}
@@ -361,8 +363,7 @@ impl Endpoint {
361363
) -> ConnectionEvent {
362364
let mut ids = vec![];
363365
for _ in 0..num {
364-
let id = self.new_cid();
365-
self.index.insert_cid(id, ch);
366+
let id = self.new_cid(ch);
366367
let meta = &mut self.connections[ch];
367368
meta.cids_issued += 1;
368369
let sequence = meta.cids_issued;
@@ -376,10 +377,12 @@ impl Endpoint {
376377
ConnectionEvent(ConnectionEventInner::NewIdentifiers(ids, now))
377378
}
378379

379-
fn new_cid(&mut self) -> ConnectionId {
380+
/// Generate a connection ID for `ch`
381+
fn new_cid(&mut self, ch: ConnectionHandle) -> ConnectionId {
380382
loop {
381383
let cid = self.local_cid_generator.generate_cid();
382-
if !self.index.connection_ids.contains_key(&cid) {
384+
if let hash_map::Entry::Vacant(e) = self.index.connection_ids.entry(cid) {
385+
e.insert(ch);
383386
break cid;
384387
}
385388
assert!(self.local_cid_generator.cid_len() > 0);
@@ -423,8 +426,7 @@ impl Endpoint {
423426
return None;
424427
}
425428

426-
let loc_cid = self.new_cid();
427-
let server_config = self.server_config.as_ref().unwrap();
429+
let server_config = self.server_config.as_ref().unwrap().clone();
428430

429431
if self.connections.len() >= server_config.concurrent_connections as usize || self.is_full()
430432
{
@@ -434,7 +436,6 @@ impl Endpoint {
434436
addresses,
435437
crypto,
436438
&src_cid,
437-
&loc_cid,
438439
TransportError::CONNECTION_REFUSED(""),
439440
)));
440441
}
@@ -451,7 +452,6 @@ impl Endpoint {
451452
addresses,
452453
crypto,
453454
&src_cid,
454-
&loc_cid,
455455
TransportError::PROTOCOL_VIOLATION("invalid destination CID length"),
456456
)));
457457
}
@@ -461,6 +461,12 @@ impl Endpoint {
461461
// First Initial
462462
let mut random_bytes = vec![0u8; RetryToken::RANDOM_BYTES_LEN];
463463
self.rng.fill_bytes(&mut random_bytes);
464+
// The peer will use this as the DCID of its following Initials. Initial DCIDs are
465+
// looked up separately from Handshake/Data DCIDs, so there is no risk of collision
466+
// with established connections. In the unlikely event that a collision occurs
467+
// between two connections in the initial phase, both will fail fast and may be
468+
// retried by the application layer.
469+
let loc_cid = self.local_cid_generator.generate_cid();
464470

465471
let token = RetryToken {
466472
orig_dst_cid: dst_cid,
@@ -508,7 +514,6 @@ impl Endpoint {
508514
addresses,
509515
crypto,
510516
&src_cid,
511-
&loc_cid,
512517
TransportError::INVALID_TOKEN(""),
513518
)));
514519
}
@@ -517,7 +522,8 @@ impl Endpoint {
517522
(None, dst_cid)
518523
};
519524

520-
let server_config = server_config.clone();
525+
let ch = ConnectionHandle(self.connections.vacant_key());
526+
let loc_cid = self.new_cid(ch);
521527
let mut params = TransportParameters::new(
522528
&server_config.transport,
523529
&self.config,
@@ -531,7 +537,7 @@ impl Endpoint {
531537

532538
let tls = server_config.crypto.clone().start_session(version, &params);
533539
let transport_config = server_config.transport.clone();
534-
let (ch, mut conn) = self.add_connection(
540+
let mut conn = self.add_connection(
535541
version,
536542
dst_cid,
537543
loc_cid,
@@ -541,6 +547,7 @@ impl Endpoint {
541547
tls,
542548
Some(server_config),
543549
transport_config,
550+
ch,
544551
);
545552
if dst_cid.len() != 0 {
546553
self.index.insert_initial(dst_cid, ch);
@@ -554,9 +561,9 @@ impl Endpoint {
554561
debug!("handshake failed: {}", e);
555562
self.handle_event(ch, EndpointEvent(EndpointEventInner::Drained));
556563
if let ConnectionError::TransportError(e) = e {
557-
Some(DatagramEvent::Response(self.initial_close(
558-
version, addresses, crypto, &src_cid, &loc_cid, e,
559-
)))
564+
Some(DatagramEvent::Response(
565+
self.initial_close(version, addresses, crypto, &src_cid, e),
566+
))
560567
} else {
561568
None
562569
}
@@ -575,7 +582,8 @@ impl Endpoint {
575582
tls: Box<dyn crypto::Session>,
576583
server_config: Option<Arc<ServerConfig>>,
577584
transport_config: Arc<TransportConfig>,
578-
) -> (ConnectionHandle, Connection) {
585+
ch: ConnectionHandle,
586+
) -> Connection {
579587
let conn = Connection::new(
580588
self.config.clone(),
581589
server_config,
@@ -599,11 +607,11 @@ impl Endpoint {
599607
addresses,
600608
reset_token: None,
601609
});
610+
debug_assert_eq!(id, ch.0, "connection handle allocation out of sync");
602611

603-
let ch = ConnectionHandle(id);
604612
self.index.insert_conn(addresses, loc_cid, ch);
605613

606-
(ch, conn)
614+
conn
607615
}
608616

609617
fn initial_close(
@@ -612,13 +620,16 @@ impl Endpoint {
612620
addresses: FourTuple,
613621
crypto: &Keys,
614622
remote_id: &ConnectionId,
615-
local_id: &ConnectionId,
616623
reason: TransportError,
617624
) -> Transmit {
625+
// We don't need to worry about CID collisions in initial closes because the peer
626+
// shouldn't respond, and if it does, and the CID collides, we'll just drop the
627+
// unexpected response.
628+
let local_id = self.local_cid_generator.generate_cid();
618629
let number = PacketNumber::U8(0);
619630
let header = Header::Initial {
620631
dst_cid: *remote_id,
621-
src_cid: *local_id,
632+
src_cid: local_id,
622633
number,
623634
token: Bytes::new(),
624635
version,
@@ -736,11 +747,6 @@ impl ConnectionIndex {
736747
}
737748
}
738749

739-
/// Add a new CID to an existing connection
740-
fn insert_cid(&mut self, dst_cid: ConnectionId, connection: ConnectionHandle) {
741-
self.connection_ids.insert(dst_cid, connection);
742-
}
743-
744750
/// Discard a connection ID
745751
fn retire(&mut self, dst_cid: &ConnectionId) {
746752
self.connection_ids.remove(dst_cid);

0 commit comments

Comments
 (0)