Skip to content

Commit 142714c

Browse files
committed
Change SUITE_ID to u16 and rework get_context_string()
1 parent 0f16437 commit 142714c

File tree

4 files changed

+40
-34
lines changed

4 files changed

+40
-34
lines changed

src/group/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use crate::{Error, Result};
3333
pub trait Group {
3434
/// The ciphersuite identifier as dictated by
3535
/// <https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-05.txt>
36-
const SUITE_ID: usize;
36+
const SUITE_ID: u16;
3737

3838
/// The type of group elements
3939
type Elem: Copy

src/group/p256.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ pub type L = U48;
4242

4343
#[cfg(feature = "p256")]
4444
impl Group for NistP256 {
45-
const SUITE_ID: usize = 0x0003;
45+
const SUITE_ID: u16 = 0x0003;
4646

4747
type Elem = ProjectivePoint;
4848

src/group/ristretto.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ pub struct Ristretto255;
2727
// `cfg` here is only needed because of a bug in Rust's crate feature documentation. See: https://github.com/rust-lang/rust/issues/83428
2828
#[cfg(feature = "ristretto255")]
2929
impl Group for Ristretto255 {
30-
const SUITE_ID: usize = 0x0001;
30+
const SUITE_ID: u16 = 0x0001;
3131

3232
type Elem = RistrettoPoint;
3333

src/voprf.rs

+37-31
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use derive_where::DeriveWhere;
1717
use digest::core_api::BlockSizeUser;
1818
use digest::{Digest, FixedOutputReset, Output};
1919
use generic_array::sequence::Concat;
20-
use generic_array::typenum::{U1, U11, U2, U20};
20+
use generic_array::typenum::{U11, U2, U20};
2121
use generic_array::GenericArray;
2222
use rand_core::{CryptoRng, RngCore};
2323
use subtle::ConstantTimeEq;
@@ -42,8 +42,17 @@ static STR_VOPRF: [u8; 8] = *b"VOPRF08-";
4242
/// Determines the mode of operation (either base mode or verifiable mode)
4343
#[derive(Clone, Copy)]
4444
enum Mode {
45-
Base = 0,
46-
Verifiable = 1,
45+
Base,
46+
Verifiable,
47+
}
48+
49+
impl Mode {
50+
fn to_u8(self) -> u8 {
51+
match self {
52+
Mode::Base => 0,
53+
Mode::Verifiable => 1,
54+
}
55+
}
4756
}
4857

4958
////////////////////////////
@@ -418,7 +427,7 @@ impl<G: Group, H: BlockSizeUser + Digest + FixedOutputReset> NonVerifiableServer
418427
/// Corresponds to DeriveKeyPair() function from the VOPRF specification.
419428
pub fn new_from_seed(seed: &[u8]) -> Result<Self> {
420429
let dst =
421-
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Base)?);
430+
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Base));
422431
let sk = G::hash_to_scalar::<H, _, _>(Some(seed), dst)?;
423432
Ok(Self {
424433
sk,
@@ -443,11 +452,11 @@ impl<G: Group, H: BlockSizeUser + Digest + FixedOutputReset> NonVerifiableServer
443452
chain!(
444453
context,
445454
STR_CONTEXT => |x| Some(x.as_ref()),
446-
get_context_string::<G>(Mode::Base)? => |x| Some(x.as_slice()),
455+
get_context_string::<G>(Mode::Base) => |x| Some(x.as_slice()),
447456
Serialize::<U2>::from(metadata.unwrap_or_default())?,
448457
);
449458
let dst =
450-
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Base)?);
459+
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Base));
451460
let m = G::hash_to_scalar::<H, _, _>(context, dst)?;
452461
let t = self.sk + &m;
453462
let evaluation_element = blinded_element.value * &G::scalar_invert(&t);
@@ -486,7 +495,7 @@ impl<G: Group, H: BlockSizeUser + Digest + FixedOutputReset> VerifiableServer<G,
486495
/// Corresponds to DeriveKeyPair() function from the VOPRF specification.
487496
pub fn new_from_seed(seed: &[u8]) -> Result<Self> {
488497
let dst = GenericArray::from(STR_HASH_TO_SCALAR)
489-
.concat(get_context_string::<G>(Mode::Verifiable)?);
498+
.concat(get_context_string::<G>(Mode::Verifiable));
490499
let sk = G::hash_to_scalar::<H, _, _>(Some(seed), dst)?;
491500
let pk = G::base_point() * &sk;
492501
Ok(Self {
@@ -581,11 +590,11 @@ impl<G: Group, H: BlockSizeUser + Digest + FixedOutputReset> VerifiableServer<G,
581590
) -> Result<VerifiableServerBatchEvaluatePrepareResult<'a, G, H, I>> {
582591
chain!(context,
583592
STR_CONTEXT => |x| Some(x.as_ref()),
584-
get_context_string::<G>(Mode::Verifiable)? => |x| Some(x.as_slice()),
593+
get_context_string::<G>(Mode::Verifiable) => |x| Some(x.as_slice()),
585594
Serialize::<U2>::from(metadata.unwrap_or_default())?,
586595
);
587596
let dst = GenericArray::from(STR_HASH_TO_SCALAR)
588-
.concat(get_context_string::<G>(Mode::Verifiable)?);
597+
.concat(get_context_string::<G>(Mode::Verifiable));
589598
let m = G::hash_to_scalar::<H, _, _>(context, dst)?;
590599
let t = self.sk + &m;
591600
let evaluation_elements = blinded_elements
@@ -847,7 +856,7 @@ fn deterministic_blind_unchecked<G: Group, H: BlockSizeUser + Digest + FixedOutp
847856
blind: &G::Scalar,
848857
mode: Mode,
849858
) -> Result<G::Elem> {
850-
let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::<G>(mode)?);
859+
let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::<G>(mode));
851860
let hashed_point = G::hash_to_curve::<H, _>(input, dst)?;
852861
Ok(hashed_point * blind)
853862
}
@@ -884,12 +893,12 @@ where
884893
{
885894
chain!(context,
886895
STR_CONTEXT => |x| Some(x.as_ref()),
887-
get_context_string::<G>(Mode::Verifiable)? => |x| Some(x.as_slice()),
896+
get_context_string::<G>(Mode::Verifiable) => |x| Some(x.as_slice()),
888897
Serialize::<U2>::from(info)?,
889898
);
890899

891900
let dst =
892-
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Verifiable)?);
901+
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Verifiable));
893902
let m = G::hash_to_scalar::<H, _, _>(context, dst)?;
894903

895904
let g = G::base_point();
@@ -933,7 +942,7 @@ fn generate_proof<
933942
let t3 = m * &r;
934943

935944
let challenge_dst =
936-
GenericArray::from(STR_CHALLENGE).concat(get_context_string::<G>(Mode::Verifiable)?);
945+
GenericArray::from(STR_CHALLENGE).concat(get_context_string::<G>(Mode::Verifiable));
937946
chain!(
938947
h2_input,
939948
Serialize::<U2, _>::from_owned(G::to_arr(b))?,
@@ -945,7 +954,7 @@ fn generate_proof<
945954
);
946955

947956
let hash_to_scalar_dst =
948-
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Verifiable)?);
957+
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Verifiable));
949958

950959
let c_scalar = G::hash_to_scalar::<H, _, _>(h2_input, hash_to_scalar_dst)?;
951960
let s_scalar = r - &(c_scalar * &k);
@@ -970,7 +979,7 @@ fn verify_proof<G: Group, H: BlockSizeUser + Digest + FixedOutputReset>(
970979
let t3 = (m * &proof.s_scalar) + &(z * &proof.c_scalar);
971980

972981
let challenge_dst =
973-
GenericArray::from(STR_CHALLENGE).concat(get_context_string::<G>(Mode::Verifiable)?);
982+
GenericArray::from(STR_CHALLENGE).concat(get_context_string::<G>(Mode::Verifiable));
974983
chain!(
975984
h2_input,
976985
Serialize::<U2, _>::from_owned(G::to_arr(b))?,
@@ -982,7 +991,7 @@ fn verify_proof<G: Group, H: BlockSizeUser + Digest + FixedOutputReset>(
982991
);
983992

984993
let hash_to_scalar_dst =
985-
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Verifiable)?);
994+
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(Mode::Verifiable));
986995
let c = G::hash_to_scalar::<H, _, _>(h2_input, hash_to_scalar_dst)?;
987996

988997
match c.ct_eq(&proof.c_scalar).into() {
@@ -1007,7 +1016,7 @@ fn finalize_after_unblind<
10071016
info: &'a [u8],
10081017
mode: Mode,
10091018
) -> Result<FinalizeAfterUnblindResult<G, H, I, IE>> {
1010-
let finalize_dst = GenericArray::from(STR_FINALIZE).concat(get_context_string::<G>(mode)?);
1019+
let finalize_dst = GenericArray::from(STR_FINALIZE).concat(get_context_string::<G>(mode));
10111020

10121021
Ok(inputs_and_unblinded_elements
10131022
// To make a return type possible, we have to convert to a `fn` pointer,
@@ -1038,9 +1047,9 @@ fn compute_composites<G: Group, H: BlockSizeUser + Digest + FixedOutputReset>(
10381047
return Err(Error::MismatchedLengthsForCompositeInputs);
10391048
}
10401049

1041-
let seed_dst = GenericArray::from(STR_SEED).concat(get_context_string::<G>(Mode::Verifiable)?);
1050+
let seed_dst = GenericArray::from(STR_SEED).concat(get_context_string::<G>(Mode::Verifiable));
10421051
let composite_dst =
1043-
GenericArray::from(STR_COMPOSITE).concat(get_context_string::<G>(Mode::Verifiable)?);
1052+
GenericArray::from(STR_COMPOSITE).concat(get_context_string::<G>(Mode::Verifiable));
10441053

10451054
chain!(
10461055
h1_input,
@@ -1063,7 +1072,7 @@ fn compute_composites<G: Group, H: BlockSizeUser + Digest + FixedOutputReset>(
10631072
Serialize::<U2, _>::from_owned(composite_dst)?,
10641073
);
10651074
let dst = GenericArray::from(STR_HASH_TO_SCALAR)
1066-
.concat(get_context_string::<G>(Mode::Verifiable)?);
1075+
.concat(get_context_string::<G>(Mode::Verifiable));
10671076
let di = G::hash_to_scalar::<H, _, _>(h2_input, dst)?;
10681077
m = c.value * &di + &m;
10691078
z = match k_option {
@@ -1082,10 +1091,10 @@ fn compute_composites<G: Group, H: BlockSizeUser + Digest + FixedOutputReset>(
10821091

10831092
/// Generates the contextString parameter as defined in
10841093
/// <https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html>
1085-
fn get_context_string<G: Group>(mode: Mode) -> Result<GenericArray<u8, U11>> {
1086-
Ok(GenericArray::from(STR_VOPRF)
1087-
.concat(i2osp::<U1>(mode as usize)?)
1088-
.concat(i2osp::<U2>(G::SUITE_ID)?))
1094+
fn get_context_string<G: Group>(mode: Mode) -> GenericArray<u8, U11> {
1095+
GenericArray::from(STR_VOPRF)
1096+
.concat([mode.to_u8()].into())
1097+
.concat(G::SUITE_ID.to_be_bytes().into())
10891098
}
10901099

10911100
///////////
@@ -1113,18 +1122,16 @@ mod tests {
11131122
info: &[u8],
11141123
mode: Mode,
11151124
) -> Output<H> {
1116-
let dst =
1117-
GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::<G>(mode).unwrap());
1125+
let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::<G>(mode));
11181126
let point = G::hash_to_curve::<H, _>(input, dst).unwrap();
11191127

11201128
chain!(context,
11211129
STR_CONTEXT => |x| Some(x.as_ref()),
1122-
get_context_string::<G>(mode).unwrap() => |x| Some(x.as_slice()),
1130+
get_context_string::<G>(mode) => |x| Some(x.as_slice()),
11231131
Serialize::<U2>::from(info).unwrap(),
11241132
);
11251133

1126-
let dst =
1127-
GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(mode).unwrap());
1134+
let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(get_context_string::<G>(mode));
11281135
let m = G::hash_to_scalar::<H, _, _>(context, dst).unwrap();
11291136

11301137
let res = point * &G::scalar_invert(&(key + &m));
@@ -1315,8 +1322,7 @@ mod tests {
13151322
)
13161323
.unwrap();
13171324

1318-
let dst = GenericArray::from(STR_HASH_TO_GROUP)
1319-
.concat(get_context_string::<G>(Mode::Base).unwrap());
1325+
let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(get_context_string::<G>(Mode::Base));
13201326
let point = G::hash_to_curve::<H, _>(&input, dst).unwrap();
13211327
let res2 = finalize_after_unblind::<G, H, _, _>(
13221328
Some((input.as_ref(), point)).into_iter(),

0 commit comments

Comments
 (0)