Skip to content

Commit

Permalink
Implement SerializeBytes for TlsByteVecUX types (#1133)
Browse files Browse the repository at this point in the history
Co-authored-by: Franziskus Kiefer <[email protected]>
  • Loading branch information
imor and franziskuskiefer authored Aug 10, 2023
1 parent 3e65c66 commit 90a4324
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 33 deletions.
91 changes: 59 additions & 32 deletions tls_codec/src/tls_vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use serde::ser::SerializeStruct;
use std::io::{Read, Write};
use zeroize::Zeroize;

use crate::{Deserialize, DeserializeBytes, Error, Serialize, Size};
use crate::{Deserialize, DeserializeBytes, Error, Serialize, SerializeBytes, Size};

macro_rules! impl_size {
($self:ident, $size:ty, $name:ident, $len_len:literal) => {
Expand Down Expand Up @@ -129,38 +129,16 @@ macro_rules! impl_serialize {
fn serialize<W: Write>(&$self, writer: &mut W) -> Result<usize, Error> {
// Get the byte length of the content, make sure it's not too
// large and write it out.
let tls_serialized_len = $self.tls_serialized_len();
let byte_length = tls_serialized_len - $len_len;

let max_len = <$size>::MAX as usize;
debug_assert!(
byte_length <= max_len,
"Vector length can't be encoded in the vector length a {} >= {}",
byte_length,
max_len
);
if byte_length > max_len {
return Err(Error::InvalidVectorLength);
}
let (tls_serialized_len, byte_length) = $self.get_content_lengths()?;

let mut written = (byte_length as $size).tls_serialize(writer)?;
let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?;

// Now serialize the elements
for e in $self.as_slice().iter() {
written += e.tls_serialize(writer)?;
}

debug_assert_eq!(
written, tls_serialized_len,
"{} bytes should have been serialized but {} were written",
tls_serialized_len, written
);
if written != tls_serialized_len {
return Err(Error::EncodingError(format!(
"{} bytes should have been serialized but {} were written",
tls_serialized_len, written
)));
}
$self.assert_written_bytes(tls_serialized_len, written)?;
Ok(written)
}
};
Expand All @@ -173,6 +151,23 @@ macro_rules! impl_byte_serialize {
fn serialize_bytes<W: Write>(&$self, writer: &mut W) -> Result<usize, Error> {
// Get the byte length of the content, make sure it's not too
// large and write it out.
let (tls_serialized_len, byte_length) = $self.get_content_lengths()?;

let mut written = <$size as Serialize>::tls_serialize(&(byte_length as $size), writer)?;

// Now serialize the elements
written += writer.write($self.as_slice())?;

$self.assert_written_bytes(tls_serialized_len, written)?;
Ok(written)
}
};
}

macro_rules! impl_serialize_common {
($self:ident, $size:ty, $name:ident, $len_len:literal $(,#[$std_enabled:meta])?) => {
$(#[$std_enabled])?
fn get_content_lengths(&$self) -> Result<(usize, usize), Error> {
let tls_serialized_len = $self.tls_serialized_len();
let byte_length = tls_serialized_len - $len_len;

Expand All @@ -186,12 +181,11 @@ macro_rules! impl_byte_serialize {
if byte_length > max_len {
return Err(Error::InvalidVectorLength);
}
Ok((tls_serialized_len, byte_length))
}

let mut written = (byte_length as $size).tls_serialize(writer)?;

// Now serialize the elements
written += writer.write($self.as_slice())?;

$(#[$std_enabled])?
fn assert_written_bytes(&$self, tls_serialized_len: usize, written: usize) -> Result<(), Error> {
debug_assert_eq!(
written, tls_serialized_len,
"{} bytes should have been serialized but {} were written",
Expand All @@ -203,7 +197,28 @@ macro_rules! impl_byte_serialize {
tls_serialized_len, written
)));
}
Ok(written)
Ok(())
}
};
}

macro_rules! impl_serialize_bytes_bytes {
($self:ident, $size:ty, $name:ident, $len_len:literal) => {
fn serialize_bytes_bytes(&$self) -> Result<Vec<u8>, Error> {
let (tls_serialized_len, byte_length) = $self.get_content_lengths()?;

let mut vec = Vec::<u8>::with_capacity(tls_serialized_len);
let length_vec = <$size as SerializeBytes>::tls_serialize(&(byte_length as $size))?;
let mut written = length_vec.len();
vec.extend_from_slice(&length_vec);

let bytes = $self.as_slice();
vec.extend_from_slice(bytes);
written += bytes.len();

$self.assert_written_bytes(tls_serialized_len, written)?;

Ok(vec)
}
};
}
Expand Down Expand Up @@ -295,6 +310,12 @@ macro_rules! impl_tls_vec_codec_bytes {
Self::deserialize_bytes_bytes(bytes)
}
}

impl SerializeBytes for $name {
fn tls_serialize(&self) -> Result<Vec<u8>, Error> {
self.serialize_bytes_bytes()
}
}
};
}

Expand Down Expand Up @@ -791,6 +812,7 @@ macro_rules! impl_secret_tls_vec {
impl_tls_vec_codec_generic!($size, $name, $len_len, Zeroize);

impl<T: Serialize + Zeroize> $name<T> {
impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]);
impl_serialize!(self, $size, $name, $len_len);
}

Expand Down Expand Up @@ -827,6 +849,7 @@ macro_rules! impl_public_tls_vec {
impl_tls_vec_codec_generic!($size, $name, $len_len);

impl<T: Serialize> $name<T> {
impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]);
impl_serialize!(self, $size, $name, $len_len);
}

Expand All @@ -850,7 +873,9 @@ macro_rules! impl_tls_byte_vec {

impl $name {
// This implements serialize and size for all versions
impl_serialize_common!(self, $size, $name, $len_len);
impl_byte_serialize!(self, $size, $name, $len_len);
impl_serialize_bytes_bytes!(self, $size, $name, $len_len);
impl_byte_size!(self, $size, $name, $len_len);
impl_byte_deserialize!(self, $size, $name, $len_len);
}
Expand Down Expand Up @@ -887,6 +912,7 @@ macro_rules! impl_tls_byte_slice {
}

impl<'a> $name<'a> {
impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]);
impl_byte_serialize!(self, $size, $name, $len_len);
impl_byte_size!(self, $size, $name, $len_len);
}
Expand Down Expand Up @@ -942,6 +968,7 @@ macro_rules! impl_tls_slice {
}

impl<'a, T: Serialize> $name<'a, T> {
impl_serialize_common!(self, $size, $name, $len_len, #[cfg(feature = "std")]);
impl_serialize!(self, $size, $name, $len_len);
}

Expand Down
29 changes: 28 additions & 1 deletion tls_codec/tests/encode_bytes.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use tls_codec::SerializeBytes;
use tls_codec::{SerializeBytes, TlsByteVecU16, TlsByteVecU32, TlsByteVecU8};

#[test]
fn serialize_primitives() {
Expand Down Expand Up @@ -40,3 +40,30 @@ fn serialize_var_len_boundaries() {
let serialized = v.tls_serialize().expect("Error encoding vector");
assert_eq!(&serialized[0..5], &[0x80, 0, 0x40, 0, 99]);
}

#[test]
fn serialize_tls_byte_vec_u8() {
let byte_vec = TlsByteVecU8::from_slice(&[1, 2, 3]);
let actual_result = byte_vec
.tls_serialize()
.expect("Error encoding byte vector");
assert_eq!(actual_result, vec![3, 1, 2, 3]);
}

#[test]
fn serialize_tls_byte_vec_u16() {
let byte_vec = TlsByteVecU16::from_slice(&[1, 2, 3]);
let actual_result = byte_vec
.tls_serialize()
.expect("Error encoding byte vector");
assert_eq!(actual_result, vec![0, 3, 1, 2, 3]);
}

#[test]
fn serialize_tls_byte_vec_u32() {
let byte_vec = TlsByteVecU32::from_slice(&[1, 2, 3]);
let actual_result = byte_vec
.tls_serialize()
.expect("Error encoding byte vector");
assert_eq!(actual_result, vec![0, 0, 0, 3, 1, 2, 3]);
}

0 comments on commit 90a4324

Please sign in to comment.