Skip to content

Commit

Permalink
Determinant for refactor
Browse files Browse the repository at this point in the history
Added determinant for refactor (correctly this time (hopefully))
  • Loading branch information
zinkkkk authored and sarah-quinones committed Jan 24, 2025
1 parent 24d192f commit 3bf8ea5
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 0 deletions.
36 changes: 36 additions & 0 deletions faer/src/linalg/reductions/determinant.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use crate::{assert, get_global_parallelism};
use crate::internal_prelude::*;
use alloc::vec;
use dyn_stack::MemBuffer;

#[math]
pub fn determinant<T: ComplexField>(mat: MatRef<'_, T>) -> T::Canonical {
assert!(mat.nrows() == mat.ncols());

let par = get_global_parallelism();
let (m, n) = mat.shape();
let mut row_perm_fwd = vec![0usize; m];
let mut row_perm_bwd = vec![0usize; m];

let mut factors = mat.to_owned();
let count = linalg::lu::partial_pivoting::factor::lu_in_place(
factors.as_mat_mut(),
&mut row_perm_fwd,
&mut row_perm_bwd,
par,
MemStack::new(&mut MemBuffer::new(
linalg::lu::partial_pivoting::factor::lu_in_place_scratch::<usize, T>(m, n, par, default()),
)),
default(),
).0.transposition_count;

let mut det = one();
for i in 0..factors.nrows() {
det = mul(det, factors.as_mat_ref().read(i, i));
}
if count % 2 == 0 {
det
} else {
neg(det)
}
}
1 change: 1 addition & 0 deletions faer/src/linalg/reductions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pub mod norm_l2;
pub mod norm_l2_sqr;
pub mod norm_max;
pub mod sum;
pub mod determinant;
9 changes: 9 additions & 0 deletions faer/src/mat/matmut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,15 @@ impl<'a, T, Rows: Shape, Cols: Shape, RStride: Stride, CStride: Stride> MatMut<'
self.rb().sum()
}

#[inline]
/// see [`MatRef::determinant`]
pub fn determinant(&self) -> T::Canonical
where
T: Conjugate,
{
self.rb().determinant()
}

#[track_caller]
#[inline]
/// see [`MatRef::get`]
Expand Down
9 changes: 9 additions & 0 deletions faer/src/mat/matown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -955,6 +955,15 @@ impl<T, Rows: Shape, Cols: Shape> Mat<T, Rows, Cols> {
self.as_ref().norm_l2()
}

#[inline]
/// see [`MatRef::determinant`]
pub fn determinant(&self) -> T::Canonical
where
T: Conjugate,
{
self.as_ref().determinant()
}

#[track_caller]
#[inline]
/// see [`MatRef::get`]
Expand Down
10 changes: 10 additions & 0 deletions faer/src/mat/matref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1146,6 +1146,16 @@ impl<'a, T, Rows: Shape, Cols: Shape, RStride: Stride, CStride: Stride> MatRef<'
if try_const! { Conj::get::<T>().is_conj() } { conj(val) } else { val }
}

/// returns the determinant of `self`
#[inline]
#[math]
pub fn determinant(&self) -> T::Canonical
where
T: Conjugate,
{
linalg::reductions::determinant::determinant(self.canonical().as_dyn_stride().as_dyn())
}

/// returns references to the element at the given index, or submatrices if either `row`
/// or `col` is a range, with bound checks
///
Expand Down

0 comments on commit 3bf8ea5

Please sign in to comment.