@@ -106,8 +106,10 @@ use hyper_util::rt::TokioIo;
106106use sha1:: { Digest , Sha1 } ;
107107use std:: {
108108 borrow:: Cow ,
109+ collections:: BTreeSet ,
109110 future:: Future ,
110111 pin:: Pin ,
112+ str,
111113 task:: { ready, Context , Poll } ,
112114} ;
113115use tokio_tungstenite:: {
@@ -137,7 +139,7 @@ pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
137139 sec_websocket_key : Option < HeaderValue > ,
138140 on_upgrade : hyper:: upgrade:: OnUpgrade ,
139141 on_failed_upgrade : F ,
140- sec_websocket_protocol : Option < HeaderValue > ,
142+ sec_websocket_protocol : BTreeSet < HeaderValue > ,
141143}
142144
143145impl < F > std:: fmt:: Debug for WebSocketUpgrade < F > {
@@ -241,26 +243,23 @@ impl<F> WebSocketUpgrade<F> {
241243 I : IntoIterator ,
242244 I :: Item : Into < Cow < ' static , str > > ,
243245 {
244- if let Some ( req_protocols) = self
245- . sec_websocket_protocol
246- . as_ref ( )
247- . and_then ( |p| p. to_str ( ) . ok ( ) )
248- {
249- self . protocol = protocols
250- . into_iter ( )
251- // FIXME: This will often allocate a new `String` and so is less efficient than it
252- // could be. But that can't be fixed without breaking changes to the public API.
253- . map ( Into :: into)
254- . find ( |protocol| {
255- req_protocols
256- . split ( ',' )
257- . any ( |req_protocol| req_protocol. trim ( ) == protocol)
258- } )
259- . map ( |protocol| match protocol {
260- Cow :: Owned ( s) => HeaderValue :: from_str ( & s) . unwrap ( ) ,
261- Cow :: Borrowed ( s) => HeaderValue :: from_static ( s) ,
262- } ) ;
263- }
246+ self . protocol = protocols
247+ . into_iter ( )
248+ . map ( Into :: into)
249+ . find ( |proto| {
250+ // FIXME: When https://github.com/hyperium/http/pull/814
251+ // is merged + released, we can look use
252+ // `contains(proto.as_bytes())` without converting
253+ // to `HeaderValue` first.
254+ let Ok ( proto) = HeaderValue :: from_str ( proto) else {
255+ return false ;
256+ } ;
257+ self . sec_websocket_protocol . contains ( & proto)
258+ } )
259+ . map ( |protocol| match protocol {
260+ Cow :: Owned ( s) => HeaderValue :: from_str ( & s) . unwrap ( ) ,
261+ Cow :: Borrowed ( s) => HeaderValue :: from_static ( s) ,
262+ } ) ;
264263
265264 self
266265 }
@@ -276,13 +275,8 @@ impl<F> WebSocketUpgrade<F> {
276275 /// ```
277276 ///
278277 /// this method returns an iterator yielding `"soap"` and `"wamp"`.
279- pub fn requested_protocols ( & self ) -> impl Iterator < Item = & str > {
280- self . sec_websocket_protocol
281- . as_ref ( )
282- . and_then ( |p| p. to_str ( ) . ok ( ) )
283- . into_iter ( )
284- . flat_map ( |s| s. split ( ',' ) )
285- . map ( |s| s. trim ( ) )
278+ pub fn requested_protocols ( & self ) -> impl Iterator < Item = & HeaderValue > {
279+ self . sec_websocket_protocol . iter ( )
286280 }
287281
288282 /// Set the chosen WebSocket subprotocol.
@@ -500,7 +494,16 @@ where
500494 . remove :: < hyper:: upgrade:: OnUpgrade > ( )
501495 . ok_or ( ConnectionNotUpgradable ) ?;
502496
503- let sec_websocket_protocol = parts. headers . get ( header:: SEC_WEBSOCKET_PROTOCOL ) . cloned ( ) ;
497+ let sec_websocket_protocol = parts
498+ . headers
499+ . get_all ( header:: SEC_WEBSOCKET_PROTOCOL )
500+ . iter ( )
501+ . flat_map ( |val| val. as_bytes ( ) . split ( |& b| b == b',' ) )
502+ . map ( |proto| {
503+ HeaderValue :: from_bytes ( proto. trim_ascii ( ) )
504+ . expect ( "substring of HeaderValue is valid HeaderValue" )
505+ } )
506+ . collect ( ) ;
504507
505508 Ok ( Self {
506509 config : Default :: default ( ) ,
0 commit comments