Skip to content

Commit 4c09ea7

Browse files
authored
Rewrite sec-websocket-protocol handling (#3620)
1 parent 309dc56 commit 4c09ea7

File tree

5 files changed

+40
-32
lines changed

5 files changed

+40
-32
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ name: CI
22

33
env:
44
CARGO_TERM_COLOR: always
5-
MSRV: '1.78'
5+
MSRV: '1.80'
66

77
on:
88
push:

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ members = ["axum", "axum-*"]
33
resolver = "2"
44

55
[workspace.package]
6-
rust-version = "1.78"
6+
rust-version = "1.80"
77

88
[workspace.lints.rust]
99
unsafe_code = "forbid"

axum/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
(because it was already never terminating if that method wasn't used) ([#3601])
1717
- **added:** New `ListenerExt::limit_connections` allows limiting concurrent `axum::serve` connections ([#3489])
1818
- **added:** `MethodRouter::method_filter` ([#3586])
19+
- **added:** `WebSocketUpgrade::{requested_protocols, set_selected_protocol}` for more
20+
flexible subprotocol selection ([#3597])
1921
- **changed:** `serve` has an additional generic argument and can now work with any response body
2022
type, not just `axum::body::Body` ([#3205])
23+
- **changed:** Update minimum rust version to 1.80 ([#3620])
2124

2225
[#3158]: https://github.com/tokio-rs/axum/pull/3158
2326
[#3261]: https://github.com/tokio-rs/axum/pull/3261
@@ -26,6 +29,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2629
[#3601]: https://github.com/tokio-rs/axum/pull/3601
2730
[#3489]: https://github.com/tokio-rs/axum/pull/3489
2831
[#3586]: https://github.com/tokio-rs/axum/pull/3586
32+
[#3597]: https://github.com/tokio-rs/axum/pull/3597
33+
[#3620]: https://github.com/tokio-rs/axum/pull/3620
2934

3035
# 0.8.8
3136

axum/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in
111111

112112
## Minimum supported Rust version
113113

114-
axum's MSRV is 1.78.
114+
axum's MSRV is 1.80.
115115

116116
## Examples
117117

axum/src/extract/ws.rs

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,10 @@ use hyper_util::rt::TokioIo;
106106
use sha1::{Digest, Sha1};
107107
use std::{
108108
borrow::Cow,
109+
collections::BTreeSet,
109110
future::Future,
110111
pin::Pin,
112+
str,
111113
task::{ready, Context, Poll},
112114
};
113115
use tokio_tungstenite::{
@@ -137,7 +139,7 @@ pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
137139
sec_websocket_key: Option<HeaderValue>,
138140
on_upgrade: hyper::upgrade::OnUpgrade,
139141
on_failed_upgrade: F,
140-
sec_websocket_protocol: Option<HeaderValue>,
142+
sec_websocket_protocol: BTreeSet<HeaderValue>,
141143
}
142144

143145
impl<F> std::fmt::Debug for WebSocketUpgrade<F> {
@@ -241,26 +243,23 @@ impl<F> WebSocketUpgrade<F> {
241243
I: IntoIterator,
242244
I::Item: Into<Cow<'static, str>>,
243245
{
244-
if let Some(req_protocols) = self
245-
.sec_websocket_protocol
246-
.as_ref()
247-
.and_then(|p| p.to_str().ok())
248-
{
249-
self.protocol = protocols
250-
.into_iter()
251-
// FIXME: This will often allocate a new `String` and so is less efficient than it
252-
// could be. But that can't be fixed without breaking changes to the public API.
253-
.map(Into::into)
254-
.find(|protocol| {
255-
req_protocols
256-
.split(',')
257-
.any(|req_protocol| req_protocol.trim() == protocol)
258-
})
259-
.map(|protocol| match protocol {
260-
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
261-
Cow::Borrowed(s) => HeaderValue::from_static(s),
262-
});
263-
}
246+
self.protocol = protocols
247+
.into_iter()
248+
.map(Into::into)
249+
.find(|proto| {
250+
// FIXME: When https://github.com/hyperium/http/pull/814
251+
// is merged + released, we can look use
252+
// `contains(proto.as_bytes())` without converting
253+
// to `HeaderValue` first.
254+
let Ok(proto) = HeaderValue::from_str(proto) else {
255+
return false;
256+
};
257+
self.sec_websocket_protocol.contains(&proto)
258+
})
259+
.map(|protocol| match protocol {
260+
Cow::Owned(s) => HeaderValue::from_str(&s).unwrap(),
261+
Cow::Borrowed(s) => HeaderValue::from_static(s),
262+
});
264263

265264
self
266265
}
@@ -276,13 +275,8 @@ impl<F> WebSocketUpgrade<F> {
276275
/// ```
277276
///
278277
/// this method returns an iterator yielding `"soap"` and `"wamp"`.
279-
pub fn requested_protocols(&self) -> impl Iterator<Item = &str> {
280-
self.sec_websocket_protocol
281-
.as_ref()
282-
.and_then(|p| p.to_str().ok())
283-
.into_iter()
284-
.flat_map(|s| s.split(','))
285-
.map(|s| s.trim())
278+
pub fn requested_protocols(&self) -> impl Iterator<Item = &HeaderValue> {
279+
self.sec_websocket_protocol.iter()
286280
}
287281

288282
/// Set the chosen WebSocket subprotocol.
@@ -500,7 +494,16 @@ where
500494
.remove::<hyper::upgrade::OnUpgrade>()
501495
.ok_or(ConnectionNotUpgradable)?;
502496

503-
let sec_websocket_protocol = parts.headers.get(header::SEC_WEBSOCKET_PROTOCOL).cloned();
497+
let sec_websocket_protocol = parts
498+
.headers
499+
.get_all(header::SEC_WEBSOCKET_PROTOCOL)
500+
.iter()
501+
.flat_map(|val| val.as_bytes().split(|&b| b == b','))
502+
.map(|proto| {
503+
HeaderValue::from_bytes(proto.trim_ascii())
504+
.expect("substring of HeaderValue is valid HeaderValue")
505+
})
506+
.collect();
504507

505508
Ok(Self {
506509
config: Default::default(),

0 commit comments

Comments
 (0)