Description
Hello,
I was playing around with the Dim operations and wanted to implement a matrix multiplication where, instead of a vector of shape (dim x 1) and a matrix of shape (dim x dim), one provides a single vector of shape ((dim * dim + dim) x 1), basically storing the matrix and the vector in one extended vector.
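To make the packed layout concrete, here is a minimal dependency-free sketch using plain slices rather than nalgebra types. The helper name `packed_matvec` is made up for illustration, and it assumes the matrix entries are stored in column-major order (the order nalgebra uses), followed by the vector entries.

```rust
/// Hypothetical helper, not part of nalgebra: multiplies the matrix and the
/// vector that are packed together in one slice of length n*n + n.
/// Assumption: the first n*n entries hold the matrix in column-major order,
/// the last n entries hold the vector.
fn packed_matvec(n: usize, packed: &[f64]) -> Vec<f64> {
    assert_eq!(packed.len(), n * n + n, "expected n*n + n entries");
    // Split the packed data into the matrix part and the vector part.
    let (matrix, vector) = packed.split_at(n * n);
    let mut out = vec![0.0; n];
    for col in 0..n {
        for row in 0..n {
            // Column-major: element (row, col) lives at index col * n + row.
            out[row] += matrix[col * n + row] * vector[col];
        }
    }
    out
}
```

For example, the matrix [[1, 2], [3, 4]] with the vector [5, 6] is packed as `[1.0, 3.0, 2.0, 4.0, 5.0, 6.0]`, and `packed_matvec(2, &packed)` returns `[17.0, 39.0]`.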
The code works with Dyn but not with Const::<N>, although all trait bounds are fulfilled in the matrix multiplication function. For some reason I get the following error when using Const::<N>:
error[E0271]: type mismatch resolving `<ArrayStorage<f64, 6, 1> as RawStorage<f64, Const<6>>>::CStride == Const<4>`
--> tests/bug_resize_partial_view_copy.rs:63:15
|
63 | let out = matrix_multipication(dim, dim.mul(dim), &extended_vector);
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `4`, found `6`
|
= note: expected struct `Const<4>`
found struct `Const<6>`
With the following code:
use nalgebra::{allocator::Allocator, constraint::{DimEq, ShapeConstraint}, Const, DefaultAllocator, Dim, DimAdd, DimMul, DimProd, DimSum, Dyn, OMatrix, OVector, RawStorage, ReshapableStorage, Storage, ViewStorage, U1};
use rand::{thread_rng, Rng};
use rand_distr::Uniform;
fn matrix_multipication<'b, D, D2, D3>(dim: D, dim2: D2, input: &'b OVector<f64, D3>) -> OMatrix<f64, D, U1> where
D: Dim + DimMul<D>,
D2: Dim + DimAdd<D>,
D3: Dim,
ShapeConstraint: DimEq<D2, DimProd<D, D>>,
ShapeConstraint: DimEq<D3, DimSum<D2, D>>,
DefaultAllocator: Allocator<D3> + Allocator<D>,
ViewStorage<'b, f64, D2, U1, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::RStride, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::CStride>: ReshapableStorage<f64, D2, U1, D, D>,
<ViewStorage<'b, f64, D2, U1, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::RStride, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::CStride> as ReshapableStorage<f64, D2, U1, D, D>>::Output: Storage<f64, D, D>
{
let matrix = input.generic_view((0, 0), (dim2, U1))
.reshape_generic(dim, dim);
let vector = input.generic_view((dim2.value(), 0), (dim, U1));
(matrix * vector).clone()
}
#[test] // with Dyn it works
fn test_dyn() {
let elements = 10;
let dim = Dyn(elements);
// create random matrix of shape (dim x dim) and vector of shape (dim x 1)
let rng = thread_rng();
let mut rng = rng.sample_iter(Uniform::new(0., 1.0));
let matrix = OMatrix::from_iterator_generic(dim, dim, (&mut rng).take(elements*elements));
let vector = OMatrix::from_iterator_generic(dim, U1, rng.take(elements));
// save matrix and vector in extended vector of shape ((dim*dim+dim) x 1)
let mut extended_vector = OVector::zeros_generic(dim.mul(dim).add(dim), U1);
// copy the random vectors
extended_vector.generic_view_mut((0, 0), (dim.mul(dim), U1)).copy_from(&(matrix.clone()).reshape_generic(dim.mul(dim), U1));
extended_vector.generic_view_mut((dim.mul(dim).value(), 0), (dim, U1)).copy_from(&vector);
// perform matrix multiplication using the extended vector
let out = matrix_multipication(dim, dim.mul(dim), &extended_vector);
assert_eq!(out, matrix*vector)
}
Cargo raises the error when including the following:
#[test] //does not compile
fn test_const() {
let dim = Const::<2>;
// create random matrix of shape (dim x dim) and vector of shape (dim x 1)
let rng = thread_rng();
let mut rng = rng.sample_iter(Uniform::new(0., 1.0));
let matrix = OMatrix::from_iterator_generic(dim, dim, (&mut rng).take(4));
let vector = OMatrix::from_iterator_generic(dim, U1, rng.take(2));
// save matrix and vector in extended vector of shape ((dim*dim+dim) x 1)
let mut extended_vector = OVector::zeros_generic(dim.mul(dim).add(dim), U1);
// copy the random vectors
extended_vector.generic_view_mut((0, 0), (dim.mul(dim), U1)).copy_from(&(matrix.clone()).reshape_generic(dim.mul(dim), U1));
extended_vector.generic_view_mut((dim.mul(dim).value(), 0), (dim, U1)).copy_from(&vector);
// perform matrix multiplication using the extended vector
let out = matrix_multipication(dim, dim.mul(dim), &extended_vector); // <-- raises error
assert_eq!(out, matrix*vector)
}
I tried removing the DimEq bounds, without luck, and the numerical test appears to prove that the code works in the Dyn case. Tomorrow I will try switching the order of reshape and view; I have the feeling that reshape ignores the effectively smaller matrix resulting from view.
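One possible reading of the error, offered as a guess rather than a confirmed diagnosis: the message says a column stride of `Const<4>` was expected while the storage reports `Const<6>`, which would fit a bound requiring the view's column stride to equal its row count, while the view keeps the stride of its 6-row parent. The toy model below is not nalgebra code; `ColumnView`, `reshape`, and `demo` are made-up names, and the "stride must equal row count" rule is an assumption used only to show how such a type-level requirement behaves.

```rust
/// Toy stand-in for a single-column matrix view with `R` rows, where the
/// column stride is tracked in the type (similar in spirit to `Const<N>`).
struct ColumnView<'a, const R: usize, const CSTRIDE: usize> {
    data: &'a [f64],
}

/// Assumed rule for this toy: reshaping is only offered when the column
/// stride equals the row count, i.e. the type itself proves contiguity.
impl<'a, const R: usize> ColumnView<'a, R, R> {
    /// Reinterprets the R entries as a rows x cols matrix (column-major)
    /// and returns it as a vector of rows.
    fn reshape(&self, rows: usize, cols: usize) -> Vec<Vec<f64>> {
        assert_eq!(rows * cols, R);
        (0..rows)
            .map(|r| (0..cols).map(|c| self.data[c * rows + r]).collect())
            .collect()
    }
}

fn demo() -> Vec<Vec<f64>> {
    let parent = [1.0, 3.0, 2.0, 4.0, 5.0, 6.0];
    // Stride 4 matches the 4 rows, so `reshape` is available.
    let ok: ColumnView<4, 4> = ColumnView { data: &parent[..4] };
    ok.reshape(2, 2)
    // A `ColumnView<4, 6>` (4 rows viewed inside a 6-row parent) has no
    // `reshape` method in this toy, so calling it would fail to compile.
}
```

In this model, a mismatch like 4 versus 6 is caught at compile time. If the strides were runtime values instead, as with Dyn, the types would always agree, which would be consistent with the Dyn test compiling.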
I hope there is no obvious mistake here :)