Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

codec: check serialization result length #496

Open
wants to merge 7 commits into
base: v0.5.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions benchmark/src/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ impl Generic {

#[inline]
#[allow(clippy::ptr_arg)]
pub fn bin_ser(t: &Vec<u8>, buf: &mut Vec<u8>) {
buf.extend_from_slice(t)
pub fn bin_ser(t: &Vec<u8>, buf: &mut Vec<u8>) -> grpc::Result<()> {
buf.extend_from_slice(t);
Ok(())
}

#[inline]
Expand Down
4 changes: 2 additions & 2 deletions src/call/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ impl Call {
) -> Result<ClientUnaryReceiver<Resp>> {
let call = channel.create_call(method, &opt)?;
let mut payload = vec![];
(method.req_ser())(req, &mut payload);
(method.req_ser())(req, &mut payload)?;
let cq_f = check_run(BatchType::CheckRead, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_unary(
call.call,
Expand Down Expand Up @@ -157,7 +157,7 @@ impl Call {
) -> Result<ClientSStreamReceiver<Resp>> {
let call = channel.create_call(method, &opt)?;
let mut payload = vec![];
(method.req_ser())(req, &mut payload);
(method.req_ser())(req, &mut payload)?;
let cq_f = check_run(BatchType::Finish, |ctx, tag| unsafe {
grpc_sys::grpcwrap_call_start_server_streaming(
call.call,
Expand Down
2 changes: 1 addition & 1 deletion src/call/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,7 @@ impl SinkBase {
}

self.buf.clear();
ser(t, &mut self.buf);
ser(t, &mut self.buf)?;
if flags.get_buffer_hint() && self.send_metadata {
// temporary fix: buffer hint with send meta will not send out any metadata.
flags = flags.buffer_hint(false);
Expand Down
19 changes: 14 additions & 5 deletions src/call/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,20 @@ macro_rules! impl_unary_sink {
}

fn complete(mut self, status: RpcStatus, t: Option<T>) -> $rt {
let data = t.as_ref().map(|t| {
let mut buf = vec![];
(self.ser)(t, &mut buf);
buf
});
let data = match t {
Some(t) => {
let mut buf = vec![];
match (self.ser)(&t, &mut buf) {
Ok(()) => Some(buf),
Err(e) => return $rt {
call: self.call.take().unwrap(),
cq_f: None,
err: Some(e),
}
}
},
None => None,
};

let write_flags = self.write_flags;
let res = self.call.as_mut().unwrap().call(|c| {
Expand Down
29 changes: 21 additions & 8 deletions src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::call::MessageReader;
use crate::error::Result;

pub type DeserializeFn<T> = fn(MessageReader) -> Result<T>;
pub type SerializeFn<T> = fn(&T, &mut Vec<u8>);
pub type SerializeFn<T> = fn(&T, &mut Vec<u8>) -> Result<()>;

/// Defines how to serialize and deserialize between the specialized type and byte slice.
pub struct Marshaller<T> {
Expand All @@ -29,11 +29,18 @@ pub mod pb_codec {
use protobuf::{CodedInputStream, Message};

use super::MessageReader;
use crate::error::Result;
use crate::error::{Error, Result};

#[inline]
pub fn ser<T: Message>(t: &T, buf: &mut Vec<u8>) {
t.write_to_vec(buf).unwrap()
pub fn ser<T: Message>(t: &T, buf: &mut Vec<u8>) -> Result<()> {
t.write_to_vec(buf)?;
if buf.len() <= u32::MAX as usize {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment to explain that the source of this number is from https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md

Ok(())
} else {
Err(Error::Codec(
format!("message is too large: {} > u32::MAX", buf.len()).into(),
))
}
}

#[inline]
Expand All @@ -47,15 +54,21 @@ pub mod pb_codec {

#[cfg(feature = "prost-codec")]
pub mod pr_codec {
use bytes::buf::BufMut;
use prost::Message;

use super::MessageReader;
use crate::error::Result;
use crate::error::{Error, Result};

#[inline]
pub fn ser<M: Message, B: BufMut>(msg: &M, buf: &mut B) {
msg.encode(buf).expect("Writing message to buffer failed");
pub fn ser<M: Message>(msg: &M, buf: &mut Vec<u8>) -> Result<()> {
msg.encode(buf)?;
if buf.len() <= u32::MAX as usize {
Ok(())
} else {
Err(Error::Codec(
format!("message is too large: {} > u32::MAX", buf.len()).into(),
))
}
}

#[inline]
Expand Down
13 changes: 9 additions & 4 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ use std::{error, fmt, result};
use crate::call::RpcStatus;
use crate::grpc_sys::grpc_call_error;

#[cfg(feature = "prost-codec")]
use prost::DecodeError;
#[cfg(feature = "protobuf-codec")]
use protobuf::ProtobufError;

Expand Down Expand Up @@ -64,8 +62,15 @@ impl From<ProtobufError> for Error {
}

#[cfg(feature = "prost-codec")]
impl From<DecodeError> for Error {
fn from(e: DecodeError) -> Error {
impl From<prost::DecodeError> for Error {
fn from(e: prost::DecodeError) -> Error {
Error::Codec(Box::new(e))
}
}

#[cfg(feature = "prost-codec")]
impl From<prost::EncodeError> for Error {
fn from(e: prost::EncodeError) -> Error {
Error::Codec(Box::new(e))
}
}
Expand Down