Skip to content

Commit 652fd1d

Browse files
authored
Group trait overhaul (#52)
* Decouple element from `Group` * Change `SUITE_ID` to `u16` and rework `get_context_string()` * Rework scalar de-serialization * Rename `Group` methods - `random_nonzero_scalar` -> `random_scalar` - `scalar_as_bytes` -> `serialize_scalar` - `scalar_invert` -> `invert_scalar` * Rework element de-serialization * Rename and remove `Group` methods `to_arr` -> `serialize_elem` `base_point` -> `base_elem` `is_identity` -> removed `identity` -> `identity_elem` `zero_scalar` -> hidden behind `cfg(test)` * Sort `Group` methods * Rework `expand_message_xmd` and remove utility * Improve P256 `hash_to_scalar`
1 parent e767543 commit 652fd1d

13 files changed

+665
-710
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ p256 = [
2424
"once_cell",
2525
"p256_",
2626
]
27-
ristretto255 = []
27+
ristretto255 = ["generic-array/more_lengths"]
2828
ristretto255_fiat_u32 = ["curve25519-dalek/fiat_u32_backend", "ristretto255"]
2929
ristretto255_fiat_u64 = ["curve25519-dalek/fiat_u64_backend", "ristretto255"]
3030
ristretto255_simd = ["curve25519-dalek/simd_backend", "ristretto255"]

src/error.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ pub enum Error {
3434
ProofVerificationError,
3535
/// Encountered insufficient bytes when attempting to deserialize
3636
SizeError,
37-
/// Encountered a zero scalar
38-
ZeroScalarError,
37+
/// Encountered an invalid scalar
38+
ScalarError,
3939
}
4040

4141
#[cfg(feature = "std")]

src/group/expand.rs

+56-54
Original file line numberDiff line numberDiff line change
@@ -5,74 +5,82 @@
55
// License, Version 2.0 found in the LICENSE-APACHE file in the root directory
66
// of this source tree.
77

8-
use core::ops::Add;
8+
use core::convert::TryFrom;
99

10-
use digest::core_api::BlockSizeUser;
10+
use digest::core_api::{Block, BlockSizeUser};
1111
use digest::{Digest, FixedOutputReset};
12-
use generic_array::sequence::Concat;
13-
use generic_array::typenum::{Unsigned, U1, U2};
12+
use generic_array::typenum::{IsLess, NonZero, Unsigned, U65536};
1413
use generic_array::{ArrayLength, GenericArray};
1514

16-
use crate::util::i2osp;
1715
use crate::{Error, Result};
1816

19-
// Computes ceil(x / y)
20-
fn div_ceil(x: usize, y: usize) -> usize {
21-
let additive = (x % y != 0) as usize;
22-
x / y + additive
23-
}
24-
2517
fn xor<L: ArrayLength<u8>>(x: GenericArray<u8, L>, y: GenericArray<u8, L>) -> GenericArray<u8, L> {
2618
x.into_iter().zip(y).map(|(x1, x2)| x1 ^ x2).collect()
2719
}
2820

2921
/// Corresponds to the expand_message_xmd() function defined in
3022
/// <https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.txt>
31-
pub fn expand_message_xmd<
32-
'a,
33-
H: BlockSizeUser + Digest + FixedOutputReset,
34-
L: ArrayLength<u8>,
35-
M: IntoIterator<Item = &'a [u8]>,
36-
D: ArrayLength<u8> + Add<U1>,
37-
>(
38-
msg: M,
39-
dst: GenericArray<u8, D>,
23+
pub fn expand_message_xmd<H: BlockSizeUser + Digest + FixedOutputReset, L: ArrayLength<u8>>(
24+
msg: &[&[u8]],
25+
dst: &[u8],
4026
) -> Result<GenericArray<u8, L>>
4127
where
42-
<D as Add<U1>>::Output: ArrayLength<u8>,
28+
// Constraint set by `expand_message_xmd`:
29+
// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-6
30+
L: NonZero + IsLess<U65536>,
4331
{
44-
let digest_len = H::OutputSize::USIZE;
45-
let ell = div_ceil(L::USIZE, digest_len);
46-
if ell > 255 {
32+
// DST, a byte string of at most 255 bytes.
33+
let dst_len = u8::try_from(dst.len()).map_err(|_| Error::HashToCurveError)?;
34+
35+
// b_in_bytes, b / 8 for b the output size of H in bits.
36+
let b_in_bytes = H::OutputSize::to_usize();
37+
38+
// Constraint set by `expand_message_xmd`:
39+
// https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-13.html#section-5.4.1-4
40+
if b_in_bytes > H::BlockSize::USIZE {
4741
return Err(Error::HashToCurveError);
4842
}
49-
let dst_prime = dst.concat(i2osp::<U1>(D::USIZE)?);
50-
let z_pad = i2osp::<H::BlockSize>(0)?;
51-
let l_i_b_str = i2osp::<U2>(L::USIZE)?;
5243

53-
let mut h = H::new();
44+
// ell = ceil(len_in_bytes / b_in_bytes)
45+
// ABORT if ell > 255
46+
let ell = u8::try_from((L::USIZE + b_in_bytes - 1) / b_in_bytes)
47+
.map_err(|_| Error::HashToCurveError)?;
48+
49+
let mut hash = H::new();
5450

51+
// b_0 = H(msg_prime)
5552
// msg_prime = Z_pad || msg || l_i_b_str || I2OSP(0, 1) || DST_prime
56-
Digest::update(&mut h, z_pad);
57-
for bytes in msg {
58-
Digest::update(&mut h, bytes)
53+
// Z_pad = I2OSP(0, s_in_bytes)
54+
// s_in_bytes, the input block size of H, measured in bytes
55+
Digest::update(&mut hash, Block::<H>::default());
56+
for msg in msg {
57+
Digest::update(&mut hash, msg);
5958
}
60-
Digest::update(&mut h, l_i_b_str);
61-
Digest::update(&mut h, i2osp::<U1>(0)?);
62-
Digest::update(&mut h, &dst_prime);
59+
// l_i_b_str = I2OSP(len_in_bytes, 2)
60+
Digest::update(&mut hash, L::U16.to_be_bytes());
61+
Digest::update(&mut hash, [0]);
62+
// DST_prime = DST || I2OSP(len(DST), 1)
63+
Digest::update(&mut hash, dst);
64+
Digest::update(&mut hash, [dst_len]);
65+
let b_0 = hash.finalize_reset();
6366

64-
// b[0]
65-
let b_0 = h.finalize_reset();
6667
let mut b_i = GenericArray::default();
6768

6869
let mut uniform_bytes = GenericArray::default();
6970

70-
for (i, chunk) in (1..(ell + 1)).zip(uniform_bytes.chunks_mut(digest_len)) {
71-
Digest::update(&mut h, xor(b_0.clone(), b_i.clone()));
72-
Digest::update(&mut h, i2osp::<U1>(i)?);
73-
Digest::update(&mut h, &dst_prime);
74-
b_i = h.finalize_reset();
75-
chunk.copy_from_slice(&b_i[..digest_len.min(chunk.len())]);
71+
// b_1 = H(b_0 || I2OSP(1, 1) || DST_prime)
72+
// for i in (2, ..., ell):
73+
for (i, chunk) in (1..(ell + 1)).zip(uniform_bytes.chunks_mut(b_in_bytes)) {
74+
// b_i = H(strxor(b_0, b_(i - 1)) || I2OSP(i, 1) || DST_prime)
75+
Digest::update(&mut hash, xor(b_0.clone(), b_i.clone()));
76+
Digest::update(&mut hash, [i]);
77+
// DST_prime = DST || I2OSP(len(DST), 1)
78+
Digest::update(&mut hash, dst);
79+
Digest::update(&mut hash, [dst_len]);
80+
b_i = hash.finalize_reset();
81+
// uniform_bytes = b_1 || ... || b_ell
82+
// return substr(uniform_bytes, 0, len_in_bytes)
83+
chunk.copy_from_slice(&b_i[..b_in_bytes.min(chunk.len())]);
7684
}
7785

7886
Ok(uniform_bytes)
@@ -81,7 +89,6 @@ where
8189
#[cfg(test)]
8290
mod tests {
8391
use generic_array::typenum::{U128, U32};
84-
use generic_array::GenericArray;
8592

8693
struct Params {
8794
msg: &'static str,
@@ -91,6 +98,8 @@ mod tests {
9198

9299
#[test]
93100
fn test_expand_message_xmd() {
101+
const DST: [u8; 27] = *b"QUUX-V01-CS02-with-expander";
102+
94103
// Test vectors taken from Section K.1 of https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-10.txt
95104
let test_vectors: alloc::vec::Vec<Params> = alloc::vec![
96105
Params {
@@ -190,20 +199,13 @@ mod tests {
190199
378fba044a31f5cb44583a892f5969dcd73b3fa128816e",
191200
},
192201
];
193-
let dst = GenericArray::from(*b"QUUX-V01-CS02-with-expander");
194202

195203
for tv in test_vectors {
196204
let uniform_bytes = match tv.len_in_bytes {
197-
32 => super::expand_message_xmd::<sha2::Sha256, U32, _, _>(
198-
Some(tv.msg.as_bytes()),
199-
dst,
200-
)
201-
.map(|bytes| bytes.to_vec()),
202-
128 => super::expand_message_xmd::<sha2::Sha256, U128, _, _>(
203-
Some(tv.msg.as_bytes()),
204-
dst,
205-
)
206-
.map(|bytes| bytes.to_vec()),
205+
32 => super::expand_message_xmd::<sha2::Sha256, U32>(&[tv.msg.as_bytes()], &DST)
206+
.map(|bytes| bytes.to_vec()),
207+
128 => super::expand_message_xmd::<sha2::Sha256, U128>(&[tv.msg.as_bytes()], &DST)
208+
.map(|bytes| bytes.to_vec()),
207209
_ => unimplemented!(),
208210
}
209211
.unwrap();

src/group/mod.rs

+48-87
Original file line numberDiff line numberDiff line change
@@ -18,47 +18,36 @@ use core::ops::{Add, Mul, Sub};
1818

1919
use digest::core_api::BlockSizeUser;
2020
use digest::{Digest, FixedOutputReset};
21-
use generic_array::typenum::U1;
2221
use generic_array::{ArrayLength, GenericArray};
2322
use rand_core::{CryptoRng, RngCore};
23+
#[cfg(feature = "ristretto255")]
24+
pub use ristretto::Ristretto255;
2425
use subtle::ConstantTimeEq;
2526
use zeroize::Zeroize;
2627

27-
use crate::{Error, Result};
28+
use crate::voprf::Mode;
29+
use crate::Result;
30+
31+
pub(crate) const STR_HASH_TO_SCALAR: [u8; 13] = *b"HashToScalar-";
32+
pub(crate) const STR_HASH_TO_GROUP: [u8; 12] = *b"HashToGroup-";
2833

2934
/// A prime-order subgroup of a base field (EC, prime-order field ...). This
3035
/// subgroup is noted additively — as in the draft RFC — in this trait.
31-
pub trait Group:
32-
Copy
33-
+ Sized
34-
+ ConstantTimeEq
35-
+ for<'a> Mul<&'a <Self as Group>::Scalar, Output = Self>
36-
+ for<'a> Add<&'a Self, Output = Self>
37-
{
36+
pub trait Group {
3837
/// The ciphersuite identifier as dictated by
3938
/// <https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-05.txt>
40-
const SUITE_ID: usize;
39+
const SUITE_ID: u16;
4140

42-
/// transforms a password and domain separation tag (DST) into a curve point
43-
fn hash_to_curve<H: BlockSizeUser + Digest + FixedOutputReset, D: ArrayLength<u8> + Add<U1>>(
44-
msg: &[u8],
45-
dst: GenericArray<u8, D>,
46-
) -> Result<Self>
47-
where
48-
<D as Add<U1>>::Output: ArrayLength<u8>;
41+
/// The type of group elements
42+
type Elem: Copy
43+
+ Sized
44+
+ ConstantTimeEq
45+
+ Zeroize
46+
+ for<'a> Mul<&'a Self::Scalar, Output = Self::Elem>
47+
+ for<'a> Add<&'a Self::Elem, Output = Self::Elem>;
4948

50-
/// Hashes a slice of pseudo-random bytes to a scalar
51-
fn hash_to_scalar<
52-
'a,
53-
H: BlockSizeUser + Digest + FixedOutputReset,
54-
D: ArrayLength<u8> + Add<U1>,
55-
I: IntoIterator<Item = &'a [u8]>,
56-
>(
57-
input: I,
58-
dst: GenericArray<u8, D>,
59-
) -> Result<Self::Scalar>
60-
where
61-
<D as Add<U1>>::Output: ArrayLength<u8>;
49+
/// The byte length necessary to represent group elements
50+
type ElemLen: ArrayLength<u8> + 'static;
6251

6352
/// The type of base field scalars
6453
type Scalar: Zeroize
@@ -67,79 +56,51 @@ pub trait Group:
6756
+ for<'a> Add<&'a Self::Scalar, Output = Self::Scalar>
6857
+ for<'a> Sub<&'a Self::Scalar, Output = Self::Scalar>
6958
+ for<'a> Mul<&'a Self::Scalar, Output = Self::Scalar>;
59+
7060
/// The byte length necessary to represent scalars
7161
type ScalarLen: ArrayLength<u8> + 'static;
7262

73-
/// Return a scalar from its fixed-length bytes representation, without
74-
/// checking if the scalar is zero.
75-
fn from_scalar_slice_unchecked(
76-
scalar_bits: &GenericArray<u8, Self::ScalarLen>,
77-
) -> Result<Self::Scalar>;
63+
/// transforms a password and domain separation tag (DST) into a curve point
64+
fn hash_to_curve<H: BlockSizeUser + Digest + FixedOutputReset>(
65+
msg: &[&[u8]],
66+
mode: Mode,
67+
) -> Result<Self::Elem>;
7868

79-
/// Return a scalar from its fixed-length bytes representation. If the
80-
/// scalar is zero, then return an error.
81-
fn from_scalar_slice<'a>(
82-
scalar_bits: impl Into<&'a GenericArray<u8, Self::ScalarLen>>,
83-
) -> Result<Self::Scalar> {
84-
let scalar = Self::from_scalar_slice_unchecked(scalar_bits.into())?;
85-
if scalar.ct_eq(&Self::scalar_zero()).into() {
86-
return Err(Error::ZeroScalarError);
87-
}
88-
Ok(scalar)
89-
}
69+
/// Hashes a slice of pseudo-random bytes to a scalar
70+
fn hash_to_scalar<H: BlockSizeUser + Digest + FixedOutputReset>(
71+
input: &[&[u8]],
72+
mode: Mode,
73+
) -> Result<Self::Scalar>;
9074

91-
/// picks a scalar at random
92-
fn random_nonzero_scalar<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Scalar;
93-
/// Serializes a scalar to bytes
94-
fn scalar_as_bytes(scalar: Self::Scalar) -> GenericArray<u8, Self::ScalarLen>;
95-
/// The multiplicative inverse of this scalar
96-
fn scalar_invert(scalar: &Self::Scalar) -> Self::Scalar;
75+
/// Get the base point for the group
76+
fn base_elem() -> Self::Elem;
9777

98-
/// The byte length necessary to represent group elements
99-
type ElemLen: ArrayLength<u8> + 'static;
78+
/// Returns the identity group element
79+
fn identity_elem() -> Self::Elem;
10080

101-
/// Return an element from its fixed-length bytes representation. This is
102-
/// the unchecked version, which does not check for deserializing the
103-
/// identity element
104-
fn from_element_slice_unchecked(element_bits: &GenericArray<u8, Self::ElemLen>)
105-
-> Result<Self>;
81+
/// Serializes the `self` group element
82+
fn serialize_elem(elem: Self::Elem) -> GenericArray<u8, Self::ElemLen>;
10683

10784
/// Return an element from its fixed-length bytes representation. If the
10885
/// element is the identity element, return an error.
109-
fn from_element_slice<'a>(
110-
element_bits: impl Into<&'a GenericArray<u8, Self::ElemLen>>,
111-
) -> Result<Self> {
112-
let elem = Self::from_element_slice_unchecked(element_bits.into())?;
113-
114-
if Self::ct_eq(&elem, &<Self as Group>::identity()).into() {
115-
// found the identity element
116-
return Err(Error::PointError);
117-
}
118-
119-
Ok(elem)
120-
}
121-
122-
/// Serializes the `self` group element
123-
fn to_arr(&self) -> GenericArray<u8, Self::ElemLen>;
86+
fn deserialize_elem(element_bits: &GenericArray<u8, Self::ElemLen>) -> Result<Self::Elem>;
12487

125-
/// Get the base point for the group
126-
fn base_point() -> Self;
127-
128-
/// Returns if the group element is equal to the identity (1)
129-
fn is_identity(&self) -> bool {
130-
self.ct_eq(&<Self as Group>::identity()).into()
131-
}
88+
/// picks a scalar at random
89+
fn random_scalar<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Scalar;
13290

133-
/// Returns the identity group element
134-
fn identity() -> Self;
91+
/// The multiplicative inverse of this scalar
92+
fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar;
13593

13694
/// Returns the scalar representing zero
137-
fn scalar_zero() -> Self::Scalar;
95+
#[cfg(test)]
96+
fn zero_scalar() -> Self::Scalar;
97+
98+
/// Serializes a scalar to bytes
99+
fn serialize_scalar(scalar: Self::Scalar) -> GenericArray<u8, Self::ScalarLen>;
138100

139-
/// Set the contents of self to the identity value
140-
fn zeroize(&mut self) {
141-
*self = <Self as Group>::identity();
142-
}
101+
/// Return a scalar from its fixed-length bytes representation. If the
102+
/// scalar is zero or invalid, then return an error.
103+
fn deserialize_scalar(scalar_bits: &GenericArray<u8, Self::ScalarLen>) -> Result<Self::Scalar>;
143104
}
144105

145106
#[cfg(test)]

0 commit comments

Comments
 (0)