Skip to content

Commit

Permalink
Use Threshold type in concrete::Policy::Thresh
Browse files Browse the repository at this point in the history
Use the `Threshold` type in `policy::concrete::Policy::Thresh` to help
maintain invariants on n and k.
  • Loading branch information
tcharding committed Oct 9, 2023
1 parent 1542226 commit d81faa3
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 79 deletions.
4 changes: 2 additions & 2 deletions src/iter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for &'a policy::Concrete<Pk> {
| Ripemd160(_) | Hash160(_) => Tree::Nullary,
And(ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()),
Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| p.as_ref()).collect()),
Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::as_ref).collect()),
Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::as_ref).collect()),
}
}
}
Expand All @@ -90,7 +90,7 @@ impl<'a, Pk: MiniscriptKey> TreeLike for Arc<policy::Concrete<Pk>> {
| Ripemd160(_) | Hash160(_) => Tree::Nullary,
And(ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()),
Or(ref v) => Tree::Nary(v.iter().map(|(_, p)| Arc::clone(p)).collect()),
Thresh(_, ref subs) => Tree::Nary(subs.iter().map(Arc::clone).collect()),
Thresh(thresh) => Tree::Nary(thresh.iter().map(Arc::clone).collect()),
}
}
}
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ mod prelude {
rc, slice,
string::{String, ToString},
sync,
vec::Vec,
vec::{self, Vec},
};
#[cfg(any(feature = "std", test))]
pub use std::{
Expand All @@ -873,7 +873,7 @@ mod prelude {
string::{String, ToString},
sync,
sync::Mutex,
vec::Vec,
vec::{self, Vec},
};

#[cfg(all(not(feature = "std"), not(test)))]
Expand Down
42 changes: 23 additions & 19 deletions src/policy/compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -920,8 +920,9 @@ where
compile_binary!(&mut l_comp[3], &mut r_comp[2], [lw, rw], Terminal::OrI);
compile_binary!(&mut r_comp[3], &mut l_comp[2], [rw, lw], Terminal::OrI);
}
Concrete::Thresh(k, ref subs) => {
let n = subs.len();
Concrete::Thresh(ref thresh) => {
let k = thresh.k();
let n = thresh.n();
let k_over_n = k as f64 / n as f64;

let mut sub_ast = Vec::with_capacity(n);
Expand All @@ -931,7 +932,7 @@ where
let mut best_ws = Vec::with_capacity(n);

let mut min_value = (0, f64::INFINITY);
for (i, ast) in subs.iter().enumerate() {
for (i, ast) in thresh.iter().enumerate() {
let sp = sat_prob * k_over_n;
//Expressions must be dissatisfiable
let dp = Some(dissat_prob.unwrap_or(0 as f64) + (1.0 - k_over_n) * sat_prob);
Expand All @@ -949,7 +950,7 @@ where
}
sub_ext_data.push(best_es[min_value.0].0);
sub_ast.push(Arc::clone(&best_es[min_value.0].1.ms));
for (i, _ast) in subs.iter().enumerate() {
for (i, _ast) in thresh.iter().enumerate() {
if i != min_value.0 {
sub_ext_data.push(best_ws[i].0);
sub_ast.push(Arc::clone(&best_ws[i].1.ms));
Expand All @@ -966,7 +967,7 @@ where
insert_wrap!(ast_ext);
}

let key_vec: Vec<Pk> = subs
let key_vec: Vec<Pk> = thresh
.iter()
.filter_map(|s| {
if let Concrete::Key(ref pk) = s.as_ref() {
Expand All @@ -978,16 +979,16 @@ where
.collect();

match Ctx::sig_type() {
SigType::Schnorr if key_vec.len() == subs.len() => {
SigType::Schnorr if key_vec.len() == thresh.n() => {
insert_wrap!(AstElemExt::terminal(Terminal::MultiA(k, key_vec)))
}
SigType::Ecdsa
if key_vec.len() == subs.len() && subs.len() <= MAX_PUBKEYS_PER_MULTISIG =>
if key_vec.len() == thresh.n() && thresh.n() <= MAX_PUBKEYS_PER_MULTISIG =>
{
insert_wrap!(AstElemExt::terminal(Terminal::Multi(k, key_vec)))
}
_ if k == subs.len() => {
let mut it = subs.iter();
_ if k == thresh.n() => {
let mut it = thresh.iter();
let mut policy = it.next().expect("No sub policy in thresh() ?").clone();
policy =
it.fold(policy, |acc, pol| Concrete::And(vec![acc, pol.clone()]).into());
Expand Down Expand Up @@ -1157,6 +1158,7 @@ mod tests {
use super::*;
use crate::miniscript::{Legacy, Segwitv0, Tap};
use crate::policy::Liftable;
use crate::threshold::Threshold;
use crate::{script_num_size, ToPublicKey};

type SPolicy = Concrete<String>;
Expand Down Expand Up @@ -1301,19 +1303,19 @@ mod tests {
let policy: BPolicy = Concrete::Or(vec![
(
127,
Arc::new(Concrete::Thresh(
Arc::new(Concrete::Thresh(Threshold::new_unchecked(
3,
key_pol[0..5].iter().map(|p| (p.clone()).into()).collect(),
)),
))),
),
(
1,
Arc::new(Concrete::And(vec![
Arc::new(Concrete::Older(Sequence::from_height(10000))),
Arc::new(Concrete::Thresh(
Arc::new(Concrete::Thresh(Threshold::new_unchecked(
2,
key_pol[5..8].iter().map(|p| (p.clone()).into()).collect(),
)),
))),
])),
),
]);
Expand Down Expand Up @@ -1430,7 +1432,7 @@ mod tests {
.iter()
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
.collect();
let big_thresh = Concrete::Thresh(*k, pubkeys);
let big_thresh = Concrete::Thresh(Threshold::new_unchecked(*k, pubkeys));
let big_thresh_ms: SegwitMiniScript = big_thresh.compile().unwrap();
if *k == 21 {
// N * (PUSH + pubkey + CHECKSIGVERIFY)
Expand Down Expand Up @@ -1466,8 +1468,8 @@ mod tests {
.collect();

let thresh_res: Result<SegwitMiniScript, _> = Concrete::Or(vec![
(1, Arc::new(Concrete::Thresh(keys_a.len(), keys_a))),
(1, Arc::new(Concrete::Thresh(keys_b.len(), keys_b))),
(1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_a.len(), keys_a)))),
(1, Arc::new(Concrete::Thresh(Threshold::new_unchecked(keys_b.len(), keys_b)))),
])
.compile();
let script_size = thresh_res.clone().and_then(|m| Ok(m.script_size()));
Expand All @@ -1484,7 +1486,8 @@ mod tests {
.iter()
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
.collect();
let thresh_res: Result<SegwitMiniScript, _> = Concrete::Thresh(keys.len(), keys).compile();
let thresh_res: Result<SegwitMiniScript, _> =
Concrete::Thresh(Threshold::new_unchecked(keys.len(), keys)).compile();
let n_elements = thresh_res
.clone()
.and_then(|m| Ok(m.max_satisfaction_witness_elements()));
Expand All @@ -1505,7 +1508,7 @@ mod tests {
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
.collect();
let thresh_res: Result<SegwitMiniScript, _> =
Concrete::Thresh(keys.len() - 1, keys).compile();
Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile();
let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count()));
assert_eq!(
thresh_res,
Expand All @@ -1519,7 +1522,8 @@ mod tests {
.iter()
.map(|pubkey| Arc::new(Concrete::Key(*pubkey)))
.collect();
let thresh_res = Concrete::Thresh(keys.len() - 1, keys).compile::<Legacy>();
let thresh_res =
Concrete::Thresh(Threshold::new_unchecked(keys.len() - 1, keys)).compile::<Legacy>();
let ops_count = thresh_res.clone().and_then(|m| Ok(m.ext.ops.op_count()));
assert_eq!(
thresh_res,
Expand Down
Loading

0 comments on commit d81faa3

Please sign in to comment.