Skip to content

Commit

Permalink
Use the decoupled num-traits.
Browse files Browse the repository at this point in the history
  • Loading branch information
murisi committed May 16, 2024
1 parent b5e87a0 commit 3f8d33b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 30 deletions.
23 changes: 16 additions & 7 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions masp_primitives/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ memuse = "0.2.1"

# - Checked arithmetic
num-traits = "0.2.14"
num-traits-decoupled = { package = "num-traits", version = "0.2.19", git = "https://github.com/heliaxdev/num-traits", rev = "e3e712fc4ecbf9d95399b0b98ee9f3d6b9973e38" }

# - Secret management
subtle = "2.2.3"
Expand Down
52 changes: 29 additions & 23 deletions masp_primitives/src/transaction/components/amount.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use borsh::schema::Fields;
use borsh::schema::{Declaration, Definition};
use borsh::BorshSchema;
use borsh::{BorshDeserialize, BorshSerialize};
use num_traits::{CheckedAdd, CheckedMul, CheckedNeg, CheckedSub, One};
use num_traits_decoupled::{CheckedAdd, CheckedMul, CheckedNeg, CheckedSub, One};
use std::cmp::Ordering;
use std::collections::btree_map::Keys;
use std::collections::btree_map::{IntoIter, Iter};
Expand Down Expand Up @@ -422,7 +422,7 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedMul,
+ CheckedMul<Output = Value>,
{
fn mul_assign(&mut self, rhs: Value) {
*self = self.clone() * rhs;
Expand All @@ -440,18 +440,19 @@ where
+ Default
+ PartialOrd
+ CheckedMul,
<Value as CheckedMul>::Output : Default + BorshSerialize + BorshDeserialize + Eq,
{
type Output = ValueSum<Unit, Value>;
type Output = ValueSum<Unit, <Value as CheckedMul>::Output>;

fn mul(self, rhs: Value) -> Self::Output {
let mut comps = BTreeMap::new();
for (atype, amount) in self.0.iter() {
comps.insert(
atype.clone(),
amount.checked_mul(&rhs).expect("overflow detected"),
amount.checked_mul(rhs).expect("overflow detected"),
);
}
comps.retain(|_, v| *v != Value::default());
comps.retain(|_, v| *v != <Value as CheckedMul>::Output::default());
ValueSum(comps)
}
}
Expand All @@ -466,7 +467,7 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedAdd,
+ CheckedAdd<Output = Value>,
{
fn add_assign(&mut self, rhs: &ValueSum<Unit, Value>) {
*self = self.clone() + rhs;
Expand All @@ -483,7 +484,7 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedAdd,
+ CheckedAdd<Output = Value>,
{
fn add_assign(&mut self, rhs: ValueSum<Unit, Value>) {
*self += &rhs
Expand All @@ -500,7 +501,7 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedAdd,
+ CheckedAdd<Output = Value>,
{
type Output = ValueSum<Unit, Value>;

Expand All @@ -519,7 +520,7 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedAdd,
+ CheckedAdd<Output = Value>,
{
type Output = ValueSum<Unit, Value>;

Expand All @@ -528,7 +529,7 @@ where
}
}

impl<Unit, Value> CheckedAdd for ValueSum<Unit, Value>
impl<Unit, Value> CheckedAdd for &ValueSum<Unit, Value>
where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Value: BorshSerialize
Expand All @@ -538,12 +539,14 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedAdd,
+ CheckedAdd<Output = Value>,
{
fn checked_add(&self, v: &Self) -> Option<Self> {
type Output = ValueSum<Unit, Value>;

fn checked_add(self, v: Self) -> Option<Self::Output> {
let mut comps = self.0.clone();
for (atype, amount) in v.components() {
comps.insert(atype.clone(), self.get(atype).checked_add(amount)?);
comps.insert(atype.clone(), self.get(atype).checked_add(*amount)?);
}
comps.retain(|_, v| *v != Value::default());
Some(ValueSum(comps))
Expand All @@ -560,7 +563,7 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedSub,
+ CheckedSub<Output = Value>,
{
fn sub_assign(&mut self, rhs: &ValueSum<Unit, Value>) {
*self = self.clone() - rhs
Expand All @@ -577,7 +580,7 @@ where
+ Copy
+ Default
+ PartialOrd
+ CheckedSub,
+ CheckedSub<Output = Value>,
{
fn sub_assign(&mut self, rhs: ValueSum<Unit, Value>) {
*self -= &rhs
Expand All @@ -595,8 +598,9 @@ where
+ Default
+ PartialOrd
+ CheckedNeg,
<Value as CheckedNeg>::Output: BorshSerialize + BorshDeserialize + Eq + Default,
{
type Output = ValueSum<Unit, Value>;
type Output = ValueSum<Unit, <Value as CheckedNeg>::Output>;

fn neg(mut self) -> Self::Output {
let mut comps = BTreeMap::new();
Expand All @@ -606,15 +610,15 @@ where
amount.checked_neg().expect("overflow detected"),
);
}
comps.retain(|_, v| *v != Value::default());
comps.retain(|_, v| *v != <Value as CheckedNeg>::Output::default());
ValueSum(comps)
}
}

impl<Unit, Value> Sub<&ValueSum<Unit, Value>> for ValueSum<Unit, Value>
where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Value: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub,
Value: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub<Output = Value>,
{
type Output = ValueSum<Unit, Value>;

Expand All @@ -626,7 +630,7 @@ where
impl<Unit, Value> Sub<ValueSum<Unit, Value>> for ValueSum<Unit, Value>
where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Value: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub,
Value: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub<Output = Value>,
{
type Output = ValueSum<Unit, Value>;

Expand All @@ -635,15 +639,17 @@ where
}
}

impl<Unit, Value> CheckedSub for ValueSum<Unit, Value>
impl<Unit, Value> CheckedSub for &ValueSum<Unit, Value>
where
Unit: Hash + Ord + BorshSerialize + BorshDeserialize + Clone,
Value: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub,
Value: BorshSerialize + BorshDeserialize + PartialEq + Eq + Copy + Default + CheckedSub<Output = Value>,
{
fn checked_sub(&self, v: &Self) -> Option<Self> {
type Output = ValueSum<Unit, Value>;

fn checked_sub(self, v: Self) -> Option<Self::Output> {
let mut comps = self.0.clone();
for (atype, amount) in v.components() {
comps.insert(atype.clone(), self.get(atype).checked_sub(amount)?);
comps.insert(atype.clone(), self.get(atype).checked_sub(*amount)?);
}
comps.retain(|_, v| *v != Value::default());
Some(ValueSum(comps))
Expand Down

0 comments on commit 3f8d33b

Please sign in to comment.