Skip to content

Commit 17a9a59

Browse files
committed
Fix compilation issues
1 parent 11e2bc2 commit 17a9a59

File tree

1 file changed

+70
-46
lines changed

1 file changed

+70
-46
lines changed

cpp/include/raft/sparse/solver/detail/lobpcg.cuh

Lines changed: 70 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,9 @@ void selectColsIf(const raft::handle_t& handle,
135135
raft::linalg::map(
136136
handle,
137137
raft::make_const_mdspan(mask),
138+
raft::make_const_mdspan(rangeVec.view()),
138139
rangeVec.view(),
139-
[] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; },
140-
rangeVec.view());
140+
[] __device__(index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1; });
141141
thrust::sort(rmm::exec_policy(stream),
142142
rangeVec.data_handle(),
143143
rangeVec.data_handle() + rangeVec.size(),
@@ -172,11 +172,11 @@ void truncEig(
172172
}
173173
if (eigVectorTrunc.has_value() && ncols > eigVectorTrunc->extent(1))
174174
raft::matrix::truncZeroOrigin(eigVectorin.data_handle(),
175-
n_rows,
175+
nrows,
176176
eigVectorTrunc->data_handle(),
177177
nrows,
178178
eigVectorTrunc->extent(1),
179-
stream);
179+
handle.get_stream());
180180
}
181181

182182
// C = A * B
@@ -447,7 +447,7 @@ bool eigh(const raft::handle_t& handle,
447447

448448
raft::linalg::eig_dc(handle, raft::make_const_mdspan(F.view()), Fvecs.view(), eigVals);
449449
raft::linalg::gemm(handle, Ri.view(), Fvecs.view(), eigVecs);
450-
return cho_success
450+
return cho_success;
451451
}
452452

453453
/**
@@ -604,8 +604,10 @@ void lobpcg(
604604
auto eigVectorBuffer = rmm::device_uvector<value_t>(size_x * size_x, stream); // rmm because of resize
605605
auto eigVectorView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(eigVectorBuffer.data(), size_x, size_x);
606606
auto eigLambda = raft::make_device_vector<value_t, index_t>(handle, size_x);
607-
eigh(handle, gramXAX.view(), eigVectorView, eigLambda.view());
608-
truncEig(handle, eigVectorView, eigLambda.view(), size_x, largest);
607+
std::optional<raft::device_matrix_view<value_t, index_t, raft::col_major>> empty_matrix_opt = std::nullopt;
608+
eigh(handle, gramXAX.view(), empty_matrix_opt, eigVectorView, eigLambda.view());
609+
610+
truncEig(handle, eigVectorView, empty_matrix_opt, eigLambda.view(), largest);
609611
// Slice not needed for first eigh
610612
// raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0,
611613
// eigVectorFull.extent(0), size_x));
@@ -623,6 +625,9 @@ void lobpcg(
623625
auto identView = raft::make_device_matrix_view<value_t, index_t, raft::col_major>(
624626
ident.data(), size_x, size_x);
625627
raft::matrix::eye(handle, identView);
628+
auto identSizeX = raft::make_device_matrix<value_t, index_t, raft::col_major>(
629+
handle, size_x, size_x);
630+
raft::matrix::eye(handle, identSizeX.view());
626631

627632
auto Pbuffer = rmm::device_uvector<value_t>(0, stream);
628633
auto APbuffer = rmm::device_uvector<value_t>(0, stream);
@@ -646,6 +651,8 @@ void lobpcg(
646651

647652
auto aux = raft::make_device_matrix<value_t, index_t, raft::col_major>(
648653
handle, n, size_x);
654+
//auto aux_sum = raft::make_device_vector<value_t, index_t>(handle, size_x);
655+
auto residual_norms = raft::make_device_vector<value_t, index_t>(handle, size_x);
649656
std::int32_t iteration_number = -1;
650657
bool restart = true;
651658
bool explicitGramFlag = false;
@@ -664,9 +671,8 @@ void lobpcg(
664671
raft::linalg::subtract(
665672
handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view());
666673

667-
auto aux_sum = raft::make_device_vector<value_t, index_t>(handle, size_x);
668674
raft::linalg::reduce(
669-
aux_sum.data_handle(),
675+
residual_norms.data_handle(),
670676
R.data_handle(),
671677
size_x,
672678
n,
@@ -677,8 +683,7 @@ void lobpcg(
677683
false,
678684
raft::sq_op());
679685

680-
auto residual_norms = raft::make_device_vector<value_t, index_t>(handle, size_x);
681-
raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());
686+
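// TODO: confirm what the reduce above (called with raft::sq_op()) leaves in residual_norms; if it is a sum of squares, a square root is still required to obtain norms. Disabled call: raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());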
682687

683688
// cupy where & active_mask
684689
raft::linalg::unary_op(handle,
@@ -720,7 +725,7 @@ void lobpcg(
720725
selectColsIf(handle, APView, active_mask.view(), activeAPView);
721726
if (B_opt.has_value()) {
722727
activeBPView = raft::make_device_matrix_view<value_t, index_t, col_major>(activeBPbuffer.data(), n, currentBlockSize);
723-
selectColsIf(handle, BPbuffer.view(), active_mask.view(), activeBPView);
728+
selectColsIf(handle, BPView, active_mask.view(), activeBPView);
724729
}
725730
}
726731
if (M_opt.has_value()) {
@@ -823,7 +828,7 @@ void lobpcg(
823828

824829
if (!B_opt.has_value()) {
825830
// Shared memory assignments to simplify the code
826-
BXView = X.view();
831+
BXView = X;
827832
activeBRView = activeR.view();
828833
if (!restart)
829834
activeBPView = activePView;
@@ -906,9 +911,9 @@ void lobpcg(
906911
auto gramB = raft::make_device_matrix<value_t, index_t, col_major>(handle, gramDim, gramDim);
907912
auto gramAView = gramA.view();
908913
auto gramBView = gramB.view();
909-
auto eigLambdaTemp = raft::make_device_vector_view<value_t, index_t>(handle, gramDim);
914+
auto eigLambdaTemp = raft::make_device_vector<value_t, index_t>(handle, gramDim);
910915
auto eigVectorTemp =
911-
raft::make_device_matrix_view<value_t, index_t, raft::col_major>(handle, gramDim, gramDim);
916+
raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, gramDim, gramDim);
912917
auto eigLambdaTempView = eigLambdaTemp.view();
913918
auto eigVectorTempView = eigVectorTemp.view();
914919
eigVectorBuffer.resize(gramDim * size_x, stream);
@@ -927,19 +932,19 @@ void lobpcg(
927932
handle, currentBlockSize, currentBlockSize);
928933
// create transpose mat
929934
auto gramXAPT = raft::make_device_matrix<value_t, index_t, col_major>(
930-
handle, gramXAPT.extent(1), gramXAPT.extent(0));
935+
handle, gramXAP.extent(1), gramXAP.extent(0));
931936
auto gramXART = raft::make_device_matrix<value_t, index_t, col_major>(
932-
handle, gramXART.extent(1), gramXART.extent(0));
937+
handle, gramXAR.extent(1), gramXAR.extent(0));
933938
auto gramRAPT = raft::make_device_matrix<value_t, index_t, col_major>(
934-
handle, gramRAPT.extent(1), gramRAPT.extent(0));
939+
handle, gramRAP.extent(1), gramRAP.extent(0));
935940
auto gramXBPT = raft::make_device_matrix<value_t, index_t, col_major>(
936-
handle, gramXBPT.extent(1), gramXBPT.extent(0));
941+
handle, gramXBP.extent(1), gramXBP.extent(0));
937942
auto gramXBRT = raft::make_device_matrix<value_t, index_t, col_major>(
938-
handle, gramXBRT.extent(1), gramXBRT.extent(0));
943+
handle, gramXBR.extent(1), gramXBR.extent(0));
939944
auto gramRBPT = raft::make_device_matrix<value_t, index_t, col_major>(
940-
handle, gramRBPT.extent(1), gramRBPT.extent(0));
945+
handle, gramRBP.extent(1), gramRBP.extent(0));
941946
raft::linalg::transpose(handle, gramXAR.view(), gramXART.view());
942-
raft::linalg::transpose(handle, gramXVR.view(), gramXBRT.view());
947+
raft::linalg::transpose(handle, gramXBR.view(), gramXBRT.view());
943948

944949
if (!restart) {
945950
raft::linalg::gemm(handle,
@@ -1005,19 +1010,19 @@ void lobpcg(
10051010
gramBView =
10061011
raft::make_device_matrix_view<value_t, index_t, col_major>(gramB.data_handle(), n, n);
10071012

1008-
bmat(handle, gramAView, A_blocks);
1009-
bmat(handle, gramBView, B_blocks);
1013+
bmat(handle, gramAView, A_blocks, 3);
1014+
bmat(handle, gramBView, B_blocks, 3);
10101015

10111016
bool eig_sucess =
1012-
eigh(handle, gramA, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView);
1017+
eigh(handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView);
10131018
if (!eig_sucess) restart = true;
10141019
}
10151020
if (restart) {
10161021
gramDim = gramXAX.extent(1) + gramXAR.extent(1);
10171022
std::vector<raft::device_matrix_view<value_t, index_t, col_major>> A_blocks = {
1018-
gramXAX, gramXAR, gramXART, gramRAR};
1023+
gramXAX.view(), gramXAR.view(), gramXART.view(), gramRAR.view()};
10191024
std::vector<raft::device_matrix_view<value_t, index_t, col_major>> B_blocks = {
1020-
gramXBX, gramXBR, gramXBRT, gramRBR};
1025+
gramXBX.view(), gramXBR.view(), gramXBRT.view(), gramRBR.view()};
10211026
gramAView = raft::make_device_matrix_view<value_t, index_t, col_major>(
10221027
gramA.data_handle(), gramDim, gramDim);
10231028
gramBView = raft::make_device_matrix_view<value_t, index_t, col_major>(
@@ -1026,8 +1031,8 @@ void lobpcg(
10261031
raft::make_device_vector_view<value_t, index_t>(eigLambdaTempView.data_handle(), gramDim);
10271032
eigVectorTempView = raft::make_device_matrix_view<value_t, index_t, col_major>(
10281033
eigVectorTempView.data_handle(), gramDim, gramDim);
1029-
bmat(handle, gramAView, A_blocks);
1030-
bmat(handle, gramBView, B_blocks);
1034+
bmat(handle, gramAView, A_blocks, 2);
1035+
bmat(handle, gramBView, B_blocks, 2);
10311036
bool eig_sucess = eigh(
10321037
handle, gramAView, std::make_optional(gramBView), eigVectorTempView, eigLambdaTempView);
10331038
ASSERT(eig_sucess, "lobpcg: eigh has failed in lobpcg iterations");
@@ -1048,20 +1053,20 @@ void lobpcg(
10481053
auto app = raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, n, size_x);
10491054
if (B_opt.has_value()) {
10501055
auto bpp = raft::make_device_matrix<value_t, index_t, raft::col_major>(handle, n, size_x);
1051-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(),
1056+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(),
10521057
raft::matrix::slice_coordinates<index_t>(0, 0, size_x, size_x));
10531058
if (!restart) {
1054-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
1059+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
10551060
raft::matrix::slice_coordinates<index_t>(size_x, 0, size_x + currentBlockSize, size_x));
1056-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(),
1061+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(),
10571062
raft::matrix::slice_coordinates<index_t>(size_x + currentBlockSize, 0, gramDim, size_x));
10581063
} else {
1059-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
1064+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
10601065
raft::matrix::slice_coordinates<index_t>(size_x, 0, gramDim, size_x));
10611066
}
10621067

1063-
raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view());
1064-
raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view());
1068+
raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view());
1069+
raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view());
10651070
raft::linalg::gemm(handle, activeBRView, eigBlockVectorR.view(), bpp.view());
10661071
if (!restart) {
10671072
raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one);
@@ -1087,20 +1092,20 @@ void lobpcg(
10871092
raft::copy(AX.data_handle(), app.data_handle(), app.size(), stream);
10881093
raft::copy(BXView.data_handle(), bpp.data_handle(), bpp.size(), stream);
10891094
} else {
1090-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorX.view(),
1095+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorX.view(),
10911096
raft::matrix::slice_coordinates<index_t>(0, 0, size_x, size_x));
10921097
if (!restart) {
1093-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
1098+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
10941099
raft::matrix::slice_coordinates<index_t>(size_x, 0, size_x + currentBlockSize, size_x));
1095-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorP.view(),
1100+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorP.view(),
10961101
raft::matrix::slice_coordinates<index_t>(size_x + currentBlockSize, 0, gramDim, size_x));
10971102
} else {
1098-
raft::matrix::slice(handle, make_const_mdpsan(eigVectorView), eigBlockVectorR.view(),
1103+
raft::matrix::slice(handle, make_const_mdspan(eigVectorView), eigBlockVectorR.view(),
10991104
raft::matrix::slice_coordinates<index_t>(size_x, 0, gramDim, size_x));
11001105
}
11011106

1102-
raft::linalg::gemm(handle, activeRView, eigBlockVectorR.view(), pp.view());
1103-
raft::linalg::gemm(handle, activeARView, eigBlockVectorR.view(), app.view());
1107+
raft::linalg::gemm(handle, activeR.view(), eigBlockVectorR.view(), pp.view());
1108+
raft::linalg::gemm(handle, activeAR.view(), eigBlockVectorR.view(), app.view());
11041109
if (!restart) {
11051110
raft::linalg::gemm(handle, activePView, eigBlockVectorP.view(), pp.view(), one, one);
11061111
raft::linalg::gemm(handle, activeAPView, eigBlockVectorP.view(), app.view(), one, one);
@@ -1121,12 +1126,31 @@ void lobpcg(
11211126
}
11221127
}
11231128

1124-
if (B_opt.has_value()) { // Using blockVectorR instead of aux
1125-
raft::copy(R.data_handle(), BXView.data_handle(), BXView.size(), stream);
1129+
if (B_opt.has_value()) {
1130+
raft::copy(aux.data_handle(), BXView.data_handle(), BXView.size(), stream);
11261131
} else {
1127-
raft::copy(R.data_handle(), X.data_handle(), X.size(), stream);
1132+
raft::copy(aux.data_handle(), X.data_handle(), X.size(), stream);
1133+
}
1134+
raft::linalg::binary_mult_skip_zero(handle, aux.view(), make_const_mdspan(eigLambda.view()), raft::linalg::Apply::ALONG_ROWS);
1135+
1136+
raft::linalg::subtract(
1137+
handle, raft::make_const_mdspan(AX.view()), raft::make_const_mdspan(aux.view()), R.view());
1138+
1139+
raft::linalg::reduce(
1140+
residual_norms.data_handle(),
1141+
R.data_handle(),
1142+
size_x,
1143+
n,
1144+
value_t(0),
1145+
false,
1146+
true,
1147+
stream,
1148+
false,
1149+
raft::sq_op());
1150+
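// TODO: confirm whether the reduce above (called with raft::sq_op()) needs a square-root post-operation so that residual_norms holds norms rather than squared sums. Disabled call: raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());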
1151+
1152+
if (verbosityLevel > 0) {
1153+
/// TODO: add verbose output for verbosityLevel > 0
11281154
}
1129-
raft::linalg::binary_mult_skip_zero(handle, R.view(), make_const_mdspan(eigLambda.view()), linalg::Apply::ALONG_ROWS);
1130-
raft::linalg::gemm(handle, AX.view(),)
11311155
}
11321156
}; // namespace raft::sparse::solver::detail

0 commit comments

Comments
 (0)