Skip to content

generic_view and reshape_generic do not work together #1533

@fumagall

Description

@fumagall

Hello,

I was playing around with the Dim operations and wanted to implement a matrix-vector multiplication in which, instead of passing a vector of shape (dim x 1) and a matrix of shape (dim x dim), one passes a single vector of shape ((dim * dim + dim) x 1) that stores both the matrix and the vector.

The code works with `Dyn` but not with `Const::<N>`, although all trait bounds of the matrix multiplication function appear to be satisfied. With `Const::<N>` I get the following error:

error[E0271]: type mismatch resolving `<ArrayStorage<f64, 6, 1> as RawStorage<f64, Const<6>>>::CStride == Const<4>`
  --> tests/bug_resize_partial_view_copy.rs:63:15
   |
63 |     let out = matrix_multipication(dim, dim.mul(dim), &extended_vector);
   |               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ expected `4`, found `6`
   |
   = note: expected struct `Const<4>`
              found struct `Const<6>`

With the following code:

use nalgebra::{
    allocator::Allocator,
    constraint::{DimEq, ShapeConstraint},
    Const, DefaultAllocator, Dim, DimAdd, DimMul, DimProd, DimSum, Dyn, MatrixView, OMatrix,
    OVector, RawStorage, ReshapableStorage, Storage, ViewStorage, U1,
};
use rand::{thread_rng, Rng};
use rand_distr::Uniform;


fn matrix_multipication<'b, D, D2, D3>(dim: D, dim2: D2, input: &'b OVector<f64, D3>)  -> OMatrix<f64, D, U1> where 
D: Dim + DimMul<D>,
D2: Dim + DimAdd<D>,
D3: Dim,
ShapeConstraint: DimEq<D2, DimProd<D, D>>,
ShapeConstraint: DimEq<D3, DimSum<D2, D>>,
DefaultAllocator: Allocator<D3> + Allocator<D>,
ViewStorage<'b, f64, D2, U1, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::RStride, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::CStride>: ReshapableStorage<f64, D2, U1, D, D>,
<ViewStorage<'b, f64, D2, U1, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::RStride, <<DefaultAllocator as Allocator<D3>>::Buffer<f64> as RawStorage<f64, D3>>::CStride> as ReshapableStorage<f64, D2, U1, D, D>>::Output: Storage<f64, D, D>

{
    let matrix = input.generic_view((0, 0), (dim2, U1))
                                                                  .reshape_generic(dim, dim);
    let vector = input.generic_view((dim2.value(), 0), (dim, U1));
    (matrix * vector).clone()
}

/// Round-trips a random `Dyn`-sized matrix and vector through the packed
/// representation and checks the packed product against the direct product.
#[test] // with Dyn it works
fn test_dyn() {
    let elements = 10;
    let dim = Dyn(elements);

    // Create a random matrix of shape (dim x dim) and a vector of shape (dim x 1).
    let rng = thread_rng();
    let mut rng = rng.sample_iter(Uniform::new(0., 1.0));
    let matrix = OMatrix::from_iterator_generic(dim, dim, (&mut rng).take(elements * elements));
    let vector = OMatrix::from_iterator_generic(dim, U1, rng.take(elements));

    // Pack both into one extended vector of shape ((dim*dim + dim) x 1): first the
    // matrix entries (reshaped to a column, i.e. column-major), then the vector entries.
    let mut extended_vector = OVector::zeros_generic(dim.mul(dim).add(dim), U1);
    extended_vector
        .generic_view_mut((0, 0), (dim.mul(dim), U1))
        .copy_from(&matrix.clone().reshape_generic(dim.mul(dim), U1));
    extended_vector
        .generic_view_mut((dim.mul(dim).value(), 0), (dim, U1))
        .copy_from(&vector);

    // Multiply using the packed representation and compare with the direct product.
    let out = matrix_multipication(dim, dim.mul(dim), &extended_vector);
    assert_eq!(out, matrix * vector)
}

The compiler reports this error as soon as the following test is added:

/// Same round trip as `test_dyn`, but with compile-time (`Const`) dimensions.
///
/// This is the case that exposes the reported problem: reshaping a `generic_view`
/// of a `Const`-sized vector does not type-check, because the view's column stride
/// (`Const<6>`) differs from its row count (`Const<4>`).
#[test]
fn test_const() {
    let dim = Const::<2>;
    // Derive the element counts from `dim` so they cannot drift out of sync with it.
    let elements = dim.value();

    // Create a random matrix of shape (dim x dim) and a vector of shape (dim x 1).
    let rng = thread_rng();
    let mut rng = rng.sample_iter(Uniform::new(0., 1.0));
    let matrix = OMatrix::from_iterator_generic(dim, dim, (&mut rng).take(elements * elements));
    let vector = OMatrix::from_iterator_generic(dim, U1, rng.take(elements));

    // Pack both into one extended vector of shape ((dim*dim + dim) x 1): first the
    // matrix entries (reshaped to a column, i.e. column-major), then the vector entries.
    let mut extended_vector = OVector::zeros_generic(dim.mul(dim).add(dim), U1);
    extended_vector
        .generic_view_mut((0, 0), (dim.mul(dim), U1))
        .copy_from(&matrix.clone().reshape_generic(dim.mul(dim), U1));
    extended_vector
        .generic_view_mut((dim.mul(dim).value(), 0), (dim, U1))
        .copy_from(&vector);

    // Multiply using the packed representation and compare with the direct product.
    let out = matrix_multipication(dim, dim.mul(dim), &extended_vector);
    assert_eq!(out, matrix * vector)
}

I tried removing the `DimEq` bounds without success, and the numerical test suggests that the code works correctly in the `Dyn` case. Tomorrow I will try switching the order of the reshape and the view; I have the feeling that the reshape ignores the fact that the view is effectively a smaller matrix.

I hope I haven't made an obvious mistake here :)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions