Skip to content

Lanczos Solver which=SA,SM,LA,LM argument #2628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 56 commits into
base: branch-25.08
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
6327b2f
init commit lanczos solver which argument
aamijar Apr 8, 2025
43a5ca5
update pytest
aamijar Apr 8, 2025
d1ff91e
pre-commit
aamijar Apr 8, 2025
b227f83
pre-commit
aamijar Apr 8, 2025
0077dd0
pre-commit
aamijar Apr 8, 2025
e78dda9
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 8, 2025
62f7b52
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 8, 2025
6022c6b
remove comments
aamijar Apr 8, 2025
eda2b99
set default argument for eigen_solver_config_t
aamijar Apr 8, 2025
c4d543e
refactor
aamijar Apr 8, 2025
550c911
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 10, 2025
168c3cd
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 14, 2025
671fdc1
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 15, 2025
b3a09b0
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 19, 2025
906ba4c
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 22, 2025
52bec42
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar Apr 29, 2025
8caa1bf
resolving pr comments
aamijar Apr 29, 2025
38ae702
resolving pr comments
aamijar Apr 30, 2025
1da9a59
refactor which argument in lanczos_solve_ritz
aamijar Apr 30, 2025
12f61fd
pre-commit
aamijar Apr 30, 2025
6fbba71
doxygen format
aamijar May 6, 2025
a0422cc
update docs
aamijar May 6, 2025
5e2e3df
update docs
aamijar May 6, 2025
0ce7c31
fix enum
aamijar May 6, 2025
266db4c
add gtests
aamijar May 7, 2025
3d5823c
doxygen format
aamijar May 7, 2025
18cad5b
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 7, 2025
3cd12bb
update gtests
aamijar May 7, 2025
b601cbc
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 7, 2025
ad1acff
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 8, 2025
2c4b014
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 14, 2025
2df2819
rename namespace
aamijar May 14, 2025
005de7b
Merge branch 'branch-25.06' into lanczos-solver-which-argument
cjnolet May 14, 2025
1c7601c
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 14, 2025
6c5d370
test ci
aamijar May 15, 2025
ab7ed27
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 15, 2025
cb5ad1d
test ci
aamijar May 15, 2025
88cdf91
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 15, 2025
de18fc2
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 15, 2025
2409b8e
remove unused
aamijar May 15, 2025
88b575a
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 15, 2025
b34dc3e
update pytest
aamijar May 21, 2025
56ccc7c
Merge branch 'branch-25.06' into lanczos-solver-which-argument
aamijar May 21, 2025
eead758
test ci
aamijar May 21, 2025
d221be2
test ci
aamijar May 21, 2025
b2166f6
test ci
aamijar May 29, 2025
5eb1d3d
Merge branch 'branch-25.08' into lanczos-solver-which-argument
aamijar May 29, 2025
d49830b
test ci
aamijar May 29, 2025
eb5fa14
test ci
aamijar May 29, 2025
3cd8f21
test ci
aamijar May 31, 2025
594b825
test ci shift-inverse
aamijar Jun 3, 2025
020e4f4
remove SM
aamijar Jun 4, 2025
fbd7509
remove SM
aamijar Jun 4, 2025
932a36e
remove SM
aamijar Jun 4, 2025
09d8028
remove SM
aamijar Jun 11, 2025
f1e0b34
Merge branch 'branch-25.08' into lanczos-solver-which-argument
aamijar Jun 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 103 additions & 19 deletions cpp/include/raft/sparse/solver/detail/lanczos.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
#include <raft/linalg/transpose.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/diagonal.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/matrix/slice.cuh>
#include <raft/matrix/triangular.cuh>
Expand All @@ -62,6 +63,7 @@
#include <raft/util/cudart_utils.hpp>

#include <cuda.h>
#include <thrust/sort.h>

#include <cublasLt.h>
#include <curand.h>
Expand Down Expand Up @@ -1506,10 +1508,15 @@ void lanczos_solve_ritz(
raft::device_matrix_view<ValueTypeT, uint32_t, raft::row_major> beta,
std::optional<raft::device_vector_view<ValueTypeT, uint32_t>> beta_k,
IndexTypeT k,
int which,
LANCZOS_WHICH which,
int ncv,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors,
raft::device_vector_view<ValueTypeT> eigenvalues)
raft::device_vector_view<ValueTypeT> eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major>& eigenvectors_k,
raft::device_vector_view<ValueTypeT, uint32_t>& eigenvalues_k,
raft::device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>& eigenvectors_k_slice,
raft::device_vector_view<ValueTypeT> sm_eigenvalues,
raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> sm_eigenvectors)
{
auto stream = resource::get_cuda_stream(handle);

Expand Down Expand Up @@ -1542,6 +1549,75 @@ void lanczos_solve_ritz(
triangular_matrix.data_handle(), ncv, ncv);

raft::linalg::eig_dc(handle, triangular_matrix_view, eigenvectors, eigenvalues);

IndexTypeT nEigVecs = k;

auto indices = raft::make_device_vector<int>(handle, ncv);
auto selected_indices = raft::make_device_vector<int>(handle, nEigVecs);

if (which == LANCZOS_WHICH::SA) {
eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
eigenvalues_k =
raft::make_device_vector_view<ValueTypeT, uint32_t>(eigenvalues.data_handle(), nEigVecs);
eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
} else if (which == LANCZOS_WHICH::LA) {
eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
eigenvectors.data_handle() + (ncv - nEigVecs) * ncv, ncv, nEigVecs);
eigenvalues_k = raft::make_device_vector_view<ValueTypeT, uint32_t>(
eigenvalues.data_handle() + (ncv - nEigVecs), nEigVecs);
eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
eigenvectors.data_handle() + (ncv - nEigVecs) * ncv, ncv, nEigVecs);
} else if (which == LANCZOS_WHICH::SM || which == LANCZOS_WHICH::LM) {
thrust::sequence(thrust::device, indices.data_handle(), indices.data_handle() + ncv, 0);

// Sort indices by absolute eigenvalues (magnitude) using a custom comparator
thrust::sort(thrust::device,
indices.data_handle(),
indices.data_handle() + ncv,
[eigenvalues = eigenvalues.data_handle()] __device__(int a, int b) {
return fabsf(eigenvalues[a]) < fabsf(eigenvalues[b]);
});

if (which == LANCZOS_WHICH::SM) {
// Take the first nEigVecs indices (smallest magnitude)
raft::copy(selected_indices.data_handle(), indices.data_handle(), nEigVecs, stream);
} else if (which == LANCZOS_WHICH::LM) {
// Take the last nEigVecs indices (largest magnitude)
raft::copy(
selected_indices.data_handle(), indices.data_handle() + (ncv - nEigVecs), nEigVecs, stream);
}

// Re-sort these indices by algebraic value to maintain algebraic ordering
thrust::sort(thrust::device,
selected_indices.data_handle(),
selected_indices.data_handle() + nEigVecs,
[eigenvalues = eigenvalues.data_handle()] __device__(int a, int b) {
return eigenvalues[a] < eigenvalues[b];
});
raft::matrix::gather(
handle,
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::row_major>(
eigenvalues.data_handle(), ncv, 1),
raft::make_device_vector_view<const int, uint32_t>(selected_indices.data_handle(), nEigVecs),
raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::row_major>(
sm_eigenvalues.data_handle(), nEigVecs, 1));
raft::matrix::gather(
handle,
raft::make_device_matrix_view<const ValueTypeT, uint32_t, raft::row_major>(
eigenvectors.data_handle(), ncv, ncv),
raft::make_device_vector_view<const int, uint32_t>(selected_indices.data_handle(), nEigVecs),
raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::row_major>(
sm_eigenvectors.data_handle(), nEigVecs, ncv));

eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
sm_eigenvectors.data_handle(), ncv, nEigVecs);
eigenvalues_k =
raft::make_device_vector_view<ValueTypeT, uint32_t>(sm_eigenvalues.data_handle(), nEigVecs);
eigenvectors_k_slice = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
sm_eigenvectors.data_handle(), ncv, nEigVecs);
}
}

template <typename IndexTypeT, typename ValueTypeT>
Expand Down Expand Up @@ -1698,6 +1774,7 @@ auto lanczos_smallest(
int maxIter,
int restartIter,
ValueTypeT tol,
LANCZOS_WHICH which,
ValueTypeT* eigVals_dev,
ValueTypeT* eigVecs_dev,
ValueTypeT* v0,
Expand Down Expand Up @@ -1759,20 +1836,28 @@ auto lanczos_smallest(
raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(handle, ncv, ncv);
auto eigenvalues = raft::make_device_vector<ValueTypeT, uint32_t>(handle, ncv);

raft::device_matrix_view<ValueTypeT, uint32_t, raft::col_major> eigenvectors_k;
raft::device_vector_view<ValueTypeT, uint32_t> eigenvalues_k;
raft::device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major> eigenvectors_k_slice;

auto sm_eigenvalues = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);
auto sm_eigenvectors =
raft::make_device_matrix<ValueTypeT, uint32_t, raft::col_major>(handle, ncv, nEigVecs);

lanczos_solve_ritz<IndexTypeT, ValueTypeT>(handle,
alpha.view(),
beta.view(),
std::nullopt,
nEigVecs,
0,
which,
ncv,
eigenvectors.view(),
eigenvalues.view());

auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
auto eigenvalues_k =
raft::make_device_vector_view<ValueTypeT, uint32_t>(eigenvalues.data_handle(), nEigVecs);
eigenvalues.view(),
eigenvectors_k,
eigenvalues_k,
eigenvectors_k_slice,
sm_eigenvalues.view(),
sm_eigenvectors.view());

auto ritz_eigenvectors =
raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(eigVecs_dev, n, nEigVecs);
Expand All @@ -1784,9 +1869,6 @@ auto lanczos_smallest(

auto s = raft::make_device_vector<ValueTypeT>(handle, nEigVecs);

auto eigenvectors_k_slice =
raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
s.data_handle(), 1, nEigVecs);

Expand Down Expand Up @@ -2021,12 +2103,15 @@ auto lanczos_smallest(
beta.view(),
beta_k.view(),
nEigVecs,
0,
which,
ncv,
eigenvectors.view(),
eigenvalues.view());
auto eigenvectors_k = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
eigenvalues.view(),
eigenvectors_k,
eigenvalues_k,
eigenvectors_k_slice,
sm_eigenvalues.view(),
sm_eigenvectors.view());

auto ritz_eigenvectors = raft::make_device_matrix_view<ValueTypeT, uint32_t, raft::col_major>(
eigVecs_dev, n, nEigVecs);
Expand All @@ -2036,9 +2121,6 @@ auto lanczos_smallest(
raft::linalg::gemm<ValueTypeT, uint32_t, raft::col_major, raft::col_major, raft::col_major>(
handle, V_T, eigenvectors_k, ritz_eigenvectors);

auto eigenvectors_k_slice =
raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
eigenvectors.data_handle(), ncv, nEigVecs);
auto S_matrix = raft::make_device_matrix_view<ValueTypeT, IndexTypeT, raft::col_major>(
s.data_handle(), 1, nEigVecs);

Expand Down Expand Up @@ -2089,6 +2171,7 @@ auto lanczos_compute_smallest_eigenvectors(
config.max_iterations,
config.ncv,
config.tolerance,
config.which,
eigenvalues.data_handle(),
eigenvectors.data_handle(),
v0->data_handle(),
Expand All @@ -2105,6 +2188,7 @@ auto lanczos_compute_smallest_eigenvectors(
config.max_iterations,
config.ncv,
config.tolerance,
config.which,
eigenvalues.data_handle(),
eigenvectors.data_handle(),
temp_v0.data_handle(),
Expand Down
50 changes: 45 additions & 5 deletions cpp/include/raft/sparse/solver/lanczos_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,57 @@

namespace raft::sparse::solver {

/**
* @enum LANCZOS_WHICH
* @brief Enumeration specifying which eigenvalues to compute in the Lanczos algorithm
*/
enum LANCZOS_WHICH {
/** @brief LA: Largest (algebraic) eigenvalues */
LA,
/** @brief LM: Largest (in magnitude) eigenvalues */
LM,
/** @brief SA: Smallest (algebraic) eigenvalues */
SA,
/** @brief SM: Smallest (in magnitude) eigenvalues */
SM
};

/**
* @brief Configuration parameters for the Lanczos eigensolver
*
* This structure encapsulates all configuration parameters needed to run the
* Lanczos algorithm for computing eigenvalues and eigenvectors of large sparse matrices.
*
* @tparam ValueTypeT Data type for values (float or double)
*/
template <typename ValueTypeT>
struct lanczos_solver_config {
/** The number of eigenvalues and eigenvectors to compute. Must be 1 <= k < n.*/
/** @brief The number of eigenvalues and eigenvectors to compute
* @note Must be 1 <= n_components < n, where n is the matrix dimension
*/
int n_components;
/** Maximum number of iteration. */

/** @brief Maximum number of iterations allowed for the algorithm to converge */
int max_iterations;
/** The number of Lanczos vectors generated. Must be k + 1 < ncv < n. */

/** @brief The number of Lanczos vectors to generate
* @note Must satisfy n_components + 1 < ncv < n, where n is the matrix dimension
*/
int ncv;
/** Tolerance for residuals ``||Ax - wx||`` */

/** @brief Convergence tolerance for residuals
* @note Used to determine when to stop iteration based on ||Ax - wx|| < tolerance
*/
ValueTypeT tolerance;
/** random seed */

/** @brief Specifies which eigenvalues to compute in the Lanczos algorithm
* @see LANCZOS_WHICH for possible values (SA, LA, SM, LM)
*/
LANCZOS_WHICH which;

/** @brief Random seed for initialization of the algorithm
* @note Controls reproducibility of results
*/
uint64_t seed;
};

Expand Down
11 changes: 9 additions & 2 deletions cpp/include/raft/spectral/eigen_solvers.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ struct eigen_solver_config_t {
1234567}; // CAVEAT: this default value is now common to all instances of using seed in
// Lanczos; was not the case before: there were places where a default seed = 123456
// was used; this may trigger slightly different # solver iterations

raft::sparse::solver::LANCZOS_WHICH which{raft::sparse::solver::LANCZOS_WHICH::SA};
};

template <typename index_type_t, typename value_type_t, typename size_type_t = index_type_t>
Expand Down Expand Up @@ -79,8 +81,13 @@ struct lanczos_solver_t {
RAFT_EXPECTS(eigVals != nullptr, "Null eigVals buffer.");
RAFT_EXPECTS(eigVecs != nullptr, "Null eigVecs buffer.");

auto lanczos_config = raft::sparse::solver::lanczos_solver_config<value_type_t>{
config_.n_eigVecs, config_.maxIter, config_.restartIter, config_.tol, config_.seed};
auto lanczos_config =
raft::sparse::solver::lanczos_solver_config<value_type_t>{config_.n_eigVecs,
config_.maxIter,
config_.restartIter,
config_.tol,
config_.which,
config_.seed};
auto v0_opt = std::optional<raft::device_vector_view<value_type_t, uint32_t, raft::row_major>>{
std::nullopt};
auto input_structure = input.structure_view();
Expand Down
Loading