Skip to content

Commit bd90307

Browse files
committed
Use BlockSize in spmv
1 parent 7d5e924 commit bd90307

File tree

2 files changed: 16 additions and 24 deletions.

2 files changed

+16
-24
lines changed

cpp/dolfinx/la/MatrixCSR.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <mpi.h>
1717
#include <numeric>
1818
#include <span>
19+
#include <type_traits>
1920
#include <utility>
2021
#include <vector>
2122

@@ -788,13 +789,13 @@ void MatrixCSR<Scalar, V, W, X>::mult(la::Vector<Scalar>& x,
788789
// yi[0] += Ai[0] * xi[0]
789790
if (_bs[1] == 1)
790791
{
791-
impl::spmv<Scalar, 1>(Avalues, Arow_begin, Aoff_diag_offset, Acols, _x, _y,
792-
_bs[0], 1);
792+
impl::spmv<Scalar, BS<1>>(Avalues, Arow_begin, Aoff_diag_offset, Acols, _x,
793+
_y, _bs[0], BS<1>());
793794
}
794795
else
795796
{
796-
impl::spmv<Scalar, -1>(Avalues, Arow_begin, Aoff_diag_offset, Acols, _x, _y,
797-
_bs[0], _bs[1]);
797+
impl::spmv<Scalar, int>(Avalues, Arow_begin, Aoff_diag_offset, Acols, _x,
798+
_y, _bs[0], _bs[1]);
798799
}
799800

800801
// finalize ghost update
@@ -804,13 +805,13 @@ void MatrixCSR<Scalar, V, W, X>::mult(la::Vector<Scalar>& x,
804805
// yi[0] += Ai[1] * xi[1]
805806
if (_bs[1] == 1)
806807
{
807-
impl::spmv<Scalar, 1>(Avalues, Aoff_diag_offset, Arow_end, Acols, _x, _y,
808-
_bs[0], 1);
808+
impl::spmv<Scalar, BS<1>>(Avalues, Aoff_diag_offset, Arow_end, Acols, _x,
809+
_y, _bs[0], BS<1>());
809810
}
810811
else
811812
{
812-
impl::spmv<Scalar, -1>(Avalues, Aoff_diag_offset, Arow_end, Acols, _x, _y,
813-
_bs[0], _bs[1]);
813+
impl::spmv<Scalar, int>(Avalues, Aoff_diag_offset, Arow_end, Acols, _x, _y,
814+
_bs[0], _bs[1]);
814815
}
815816
}
816817

cpp/dolfinx/la/matrix_csr_impl.h

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#pragma once
88

9+
#include "dolfinx/common/types.h"
910
#include <numeric>
1011
#include <span>
1112
#include <utility>
@@ -222,12 +223,13 @@ void insert_nonblocked_csr(U&& data, const V& cols, const W& row_ptr,
222223
/// @param y
223224
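/// @param bs0 Block size bounding the outer `k0` loop; within each stored block, `values` is indexed as `k1 * bs0 + k0`, so `k0` varies fastest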
224225
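/// @param bs1 Block size applied to `x`: block column `indices[j]` addresses the `bs1` entries `x[indices[j] * bs1 + k1]`, and each stored block holds `bs0 * bs1` values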
225-
template <typename T, int BS1>
226+
template <typename T, BlockSize BS1>
226227
void spmv(std::span<const T> values, std::span<const std::int64_t> row_begin,
227228
std::span<const std::int64_t> row_end,
228229
std::span<const std::int32_t> indices, std::span<const T> x,
229-
std::span<T> y, int bs0, int bs1)
230+
std::span<T> y, int bs0, BS1 _bs1)
230231
{
232+
int bs1 = block_size(_bs1);
231233
assert(row_begin.size() == row_end.size());
232234
for (int k0 = 0; k0 < bs0; ++k0)
233235
{
@@ -236,21 +238,10 @@ void spmv(std::span<const T> values, std::span<const std::int64_t> row_begin,
236238
T vi{0};
237239
for (std::int32_t j = row_begin[i]; j < row_end[i]; j++)
238240
{
239-
if constexpr (BS1 == -1)
241+
for (int k1 = 0; k1 < bs1; ++k1)
240242
{
241-
for (int k1 = 0; k1 < bs1; ++k1)
242-
{
243-
vi += values[j * bs1 * bs0 + k1 * bs0 + k0]
244-
* x[indices[j] * bs1 + k1];
245-
}
246-
}
247-
else
248-
{
249-
for (int k1 = 0; k1 < BS1; ++k1)
250-
{
251-
vi += values[j * BS1 * bs0 + k1 * bs0 + k0]
252-
* x[indices[j] * BS1 + k1];
253-
}
243+
vi += values[j * bs1 * bs0 + k1 * bs0 + k0]
244+
* x[indices[j] * bs1 + k1];
254245
}
255246
}
256247

0 commit comments

Comments
 (0)