diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index 98968048a..5e372349e 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -160,7 +160,8 @@ impl GrpcWebCall { impl GrpcWebCall where - B: Body, + B: Body, + B::Data: Buf, B::Error: Error, { // Poll body for data, decoding (e.g. via Base64 if necessary) and returning frames @@ -169,7 +170,7 @@ where fn poll_decode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, Status>>> { + ) -> Poll, Status>>> { match self.encoding { Encoding::Base64 => loop { if let Some(bytes) = self.as_mut().decode_chunk()? { @@ -179,7 +180,9 @@ where let mut this = self.as_mut().project(); match ready!(this.inner.as_mut().poll_frame(cx)) { - Some(Ok(frame)) if frame.is_data() => this.buf.put(frame.into_data().unwrap()), + Some(Ok(frame)) if frame.is_data() => this + .buf + .put(frame.into_data().unwrap_or_else(|_| unreachable!())), Some(Ok(frame)) if frame.is_trailers() => { return Poll::Ready(Some(Err(internal_error( "malformed base64 request has unencoded trailers", @@ -201,19 +204,25 @@ where } }, - Encoding::None => self.project().inner.poll_frame(cx).map_err(internal_error), + Encoding::None => self + .project() + .inner + .poll_frame(cx) + .map_ok(|f| f.map_data(|mut d| d.copy_to_bytes(d.remaining()))) + .map_err(internal_error), } } fn poll_encode( mut self: Pin<&mut Self>, cx: &mut Context<'_>, - ) -> Poll, Status>>> { + ) -> Poll, Status>>> { let mut this = self.as_mut().project(); match ready!(this.inner.as_mut().poll_frame(cx)) { Some(Ok(frame)) if frame.is_data() => { - let mut res = frame.into_data().unwrap(); + let mut data = frame.into_data().unwrap_or_else(|_| unreachable!()); + let mut res = data.copy_to_bytes(data.remaining()); if *this.encoding == Encoding::Base64 { res = crate::util::base64::STANDARD.encode(res).into(); @@ -222,12 +231,14 @@ where Poll::Ready(Some(Ok(Frame::data(res)))) } Some(Ok(frame)) if frame.is_trailers() => { - let trailers = frame.into_trailers().expect("must be trailers"); - let mut frame = make_trailers_frame(trailers); + let trailers = frame.into_trailers().unwrap_or_else(|_| unreachable!()); + let mut res = make_trailers_frame(trailers); + if *this.encoding == Encoding::Base64 { - frame = crate::util::base64::STANDARD.encode(frame).into_bytes(); + res = crate::util::base64::STANDARD.encode(res).into(); } - Poll::Ready(Some(Ok(Frame::data(frame.into())))) + + Poll::Ready(Some(Ok(Frame::data(res)))) } Some(Ok(_)) => Poll::Ready(Some(Err(internal_error("unexpected frame type")))), Some(Err(e)) => Poll::Ready(Some(Err(internal_error(e)))), @@ -238,7 +249,7 @@ where impl Body for GrpcWebCall where - B: Body, + B: Body, B::Error: Error, { type Data = Bytes; @@ -327,7 +338,7 @@ where impl Stream for GrpcWebCall where - B: Body, + B: Body, B::Error: Error, { type Item = Result, Status>; @@ -430,17 +441,17 @@ fn decode_trailers_frame(mut buf: Bytes) -> Result, Status> { Ok(Some(map)) } -fn make_trailers_frame(trailers: HeaderMap) -> Vec { +fn make_trailers_frame(trailers: HeaderMap) -> Bytes { let trailers = encode_trailers(trailers); let len = trailers.len(); assert!(len <= u32::MAX as usize); - let mut frame = Vec::with_capacity(len + FRAME_HEADER_SIZE); - frame.push(GRPC_WEB_TRAILERS_BIT); + let mut frame = BytesMut::with_capacity(len + FRAME_HEADER_SIZE); + frame.put_u8(GRPC_WEB_TRAILERS_BIT); frame.put_u32(len as u32); - frame.extend(trailers); + frame.put_slice(&trailers); - frame + frame.freeze() } /// Search some buffer for grpc-web trailers headers and return