@@ -135,9 +135,9 @@ void selectColsIf(const raft::handle_t& handle,
135
135
raft::linalg::map (
136
136
handle,
137
137
raft::make_const_mdspan (mask),
138
+ raft::make_const_mdspan (rangeVec.view ()),
138
139
rangeVec.view (),
139
- [] __device__ (index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1 ; },
140
- rangeVec.view ());
140
+ [] __device__ (index_t mask_value, index_t idx) { return mask_value == 1 ? idx : -1 ; });
141
141
thrust::sort (rmm::exec_policy (stream),
142
142
rangeVec.data_handle (),
143
143
rangeVec.data_handle () + rangeVec.size (),
@@ -172,11 +172,11 @@ void truncEig(
172
172
}
173
173
if (eigVectorTrunc.has_value () && ncols > eigVectorTrunc->extent (1 ))
174
174
raft::matrix::truncZeroOrigin (eigVectorin.data_handle (),
175
- n_rows ,
175
+ nrows ,
176
176
eigVectorTrunc->data_handle (),
177
177
nrows,
178
178
eigVectorTrunc->extent (1 ),
179
- stream );
179
+ handle. get_stream () );
180
180
}
181
181
182
182
// C = A * B
@@ -447,7 +447,7 @@ bool eigh(const raft::handle_t& handle,
447
447
448
448
raft::linalg::eig_dc (handle, raft::make_const_mdspan (F.view ()), Fvecs.view (), eigVals);
449
449
raft::linalg::gemm (handle, Ri.view (), Fvecs.view (), eigVecs);
450
- return cho_success
450
+ return cho_success;
451
451
}
452
452
453
453
/* *
@@ -604,8 +604,10 @@ void lobpcg(
604
604
auto eigVectorBuffer = rmm::device_uvector<value_t >(size_x * size_x, stream); // rmm because of resize
605
605
auto eigVectorView = raft::make_device_matrix_view<value_t , index_t , raft::col_major>(eigVectorBuffer.data (), size_x, size_x);
606
606
auto eigLambda = raft::make_device_vector<value_t , index_t >(handle, size_x);
607
- eigh (handle, gramXAX.view (), eigVectorView, eigLambda.view ());
608
- truncEig (handle, eigVectorView, eigLambda.view (), size_x, largest);
607
+ std::optional<raft::device_matrix_view<value_t , index_t , raft::col_major>> empty_matrix_opt = std::nullopt;
608
+ eigh (handle, gramXAX.view (), empty_matrix_opt, eigVectorView, eigLambda.view ());
609
+
610
+ truncEig (handle, eigVectorView, empty_matrix_opt, eigLambda.view (), largest);
609
611
// Slice not needed for first eigh
610
612
// raft::matrix::slice(handle, eigVectorFull, eigVector, raft::matrix::slice_coordinates(0, 0,
611
613
// eigVectorFull.extent(0), size_x));
@@ -623,6 +625,9 @@ void lobpcg(
623
625
auto identView = raft::make_device_matrix_view<value_t , index_t , raft::col_major>(
624
626
ident.data (), size_x, size_x);
625
627
raft::matrix::eye (handle, identView);
628
+ auto identSizeX = raft::make_device_matrix<value_t , index_t , raft::col_major>(
629
+ handle, size_x, size_x);
630
+ raft::matrix::eye (handle, identSizeX.view ());
626
631
627
632
auto Pbuffer = rmm::device_uvector<value_t >(0 , stream);
628
633
auto APbuffer = rmm::device_uvector<value_t >(0 , stream);
@@ -646,6 +651,8 @@ void lobpcg(
646
651
647
652
auto aux = raft::make_device_matrix<value_t , index_t , raft::col_major>(
648
653
handle, n, size_x);
654
+ // auto aux_sum = raft::make_device_vector<value_t, index_t>(handle, size_x);
655
+ auto residual_norms = raft::make_device_vector<value_t , index_t >(handle, size_x);
649
656
std::int32_t iteration_number = -1 ;
650
657
bool restart = true ;
651
658
bool explicitGramFlag = false ;
@@ -664,9 +671,8 @@ void lobpcg(
664
671
raft::linalg::subtract (
665
672
handle, raft::make_const_mdspan (AX.view ()), raft::make_const_mdspan (aux.view ()), R.view ());
666
673
667
- auto aux_sum = raft::make_device_vector<value_t , index_t >(handle, size_x);
668
674
raft::linalg::reduce (
669
- aux_sum .data_handle (),
675
+ residual_norms .data_handle (),
670
676
R.data_handle (),
671
677
size_x,
672
678
n,
@@ -677,8 +683,7 @@ void lobpcg(
677
683
false ,
678
684
raft::sq_op ());
679
685
680
- auto residual_norms = raft::make_device_vector<value_t , index_t >(handle, size_x);
681
- raft::linalg::sqrt (handle, raft::make_const_mdspan (aux_sum.view ()), residual_norms.view ());
686
+ // TODO check sqop of reduce raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());
682
687
683
688
// cupy where & active_mask
684
689
raft::linalg::unary_op (handle,
@@ -720,7 +725,7 @@ void lobpcg(
720
725
selectColsIf (handle, APView, active_mask.view (), activeAPView);
721
726
if (B_opt.has_value ()) {
722
727
activeBPView = raft::make_device_matrix_view<value_t , index_t , col_major>(activeBPbuffer.data (), n, currentBlockSize);
723
- selectColsIf (handle, BPbuffer. view () , active_mask.view (), activeBPView);
728
+ selectColsIf (handle, BPView , active_mask.view (), activeBPView);
724
729
}
725
730
}
726
731
if (M_opt.has_value ()) {
@@ -823,7 +828,7 @@ void lobpcg(
823
828
824
829
if (!B_opt.has_value ()) {
825
830
// Shared memory assignments to simplify the code
826
- BXView = X. view () ;
831
+ BXView = X;
827
832
activeBRView = activeR.view ();
828
833
if (!restart)
829
834
activeBPView = activePView;
@@ -906,9 +911,9 @@ void lobpcg(
906
911
auto gramB = raft::make_device_matrix<value_t , index_t , col_major>(handle, gramDim, gramDim);
907
912
auto gramAView = gramA.view ();
908
913
auto gramBView = gramB.view ();
909
- auto eigLambdaTemp = raft::make_device_vector_view <value_t , index_t >(handle, gramDim);
914
+ auto eigLambdaTemp = raft::make_device_vector <value_t , index_t >(handle, gramDim);
910
915
auto eigVectorTemp =
911
- raft::make_device_matrix_view <value_t , index_t , raft::col_major>(handle, gramDim, gramDim);
916
+ raft::make_device_matrix <value_t , index_t , raft::col_major>(handle, gramDim, gramDim);
912
917
auto eigLambdaTempView = eigLambdaTemp.view ();
913
918
auto eigVectorTempView = eigVectorTemp.view ();
914
919
eigVectorBuffer.resize (gramDim * size_x, stream);
@@ -927,19 +932,19 @@ void lobpcg(
927
932
handle, currentBlockSize, currentBlockSize);
928
933
// create transpose mat
929
934
auto gramXAPT = raft::make_device_matrix<value_t , index_t , col_major>(
930
- handle, gramXAPT .extent (1 ), gramXAPT .extent (0 ));
935
+ handle, gramXAP .extent (1 ), gramXAP .extent (0 ));
931
936
auto gramXART = raft::make_device_matrix<value_t , index_t , col_major>(
932
- handle, gramXART .extent (1 ), gramXART .extent (0 ));
937
+ handle, gramXAR .extent (1 ), gramXAR .extent (0 ));
933
938
auto gramRAPT = raft::make_device_matrix<value_t , index_t , col_major>(
934
- handle, gramRAPT .extent (1 ), gramRAPT .extent (0 ));
939
+ handle, gramRAP .extent (1 ), gramRAP .extent (0 ));
935
940
auto gramXBPT = raft::make_device_matrix<value_t , index_t , col_major>(
936
- handle, gramXBPT .extent (1 ), gramXBPT .extent (0 ));
941
+ handle, gramXBP .extent (1 ), gramXBP .extent (0 ));
937
942
auto gramXBRT = raft::make_device_matrix<value_t , index_t , col_major>(
938
- handle, gramXBRT .extent (1 ), gramXBRT .extent (0 ));
943
+ handle, gramXBR .extent (1 ), gramXBR .extent (0 ));
939
944
auto gramRBPT = raft::make_device_matrix<value_t , index_t , col_major>(
940
- handle, gramRBPT .extent (1 ), gramRBPT .extent (0 ));
945
+ handle, gramRBP .extent (1 ), gramRBP .extent (0 ));
941
946
raft::linalg::transpose (handle, gramXAR.view (), gramXART.view ());
942
- raft::linalg::transpose (handle, gramXVR .view (), gramXBRT.view ());
947
+ raft::linalg::transpose (handle, gramXBR .view (), gramXBRT.view ());
943
948
944
949
if (!restart) {
945
950
raft::linalg::gemm (handle,
@@ -1005,19 +1010,19 @@ void lobpcg(
1005
1010
gramBView =
1006
1011
raft::make_device_matrix_view<value_t , index_t , col_major>(gramB.data_handle (), n, n);
1007
1012
1008
- bmat (handle, gramAView, A_blocks);
1009
- bmat (handle, gramBView, B_blocks);
1013
+ bmat (handle, gramAView, A_blocks, 3 );
1014
+ bmat (handle, gramBView, B_blocks, 3 );
1010
1015
1011
1016
bool eig_sucess =
1012
- eigh (handle, gramA , std::make_optional (gramBView), eigVectorTempView, eigLambdaTempView);
1017
+ eigh (handle, gramAView , std::make_optional (gramBView), eigVectorTempView, eigLambdaTempView);
1013
1018
if (!eig_sucess) restart = true ;
1014
1019
}
1015
1020
if (restart) {
1016
1021
gramDim = gramXAX.extent (1 ) + gramXAR.extent (1 );
1017
1022
std::vector<raft::device_matrix_view<value_t , index_t , col_major>> A_blocks = {
1018
- gramXAX, gramXAR, gramXART, gramRAR};
1023
+ gramXAX. view () , gramXAR. view () , gramXART. view () , gramRAR. view () };
1019
1024
std::vector<raft::device_matrix_view<value_t , index_t , col_major>> B_blocks = {
1020
- gramXBX, gramXBR, gramXBRT, gramRBR};
1025
+ gramXBX. view () , gramXBR. view () , gramXBRT. view () , gramRBR. view () };
1021
1026
gramAView = raft::make_device_matrix_view<value_t , index_t , col_major>(
1022
1027
gramA.data_handle (), gramDim, gramDim);
1023
1028
gramBView = raft::make_device_matrix_view<value_t , index_t , col_major>(
@@ -1026,8 +1031,8 @@ void lobpcg(
1026
1031
raft::make_device_vector_view<value_t , index_t >(eigLambdaTempView.data_handle (), gramDim);
1027
1032
eigVectorTempView = raft::make_device_matrix_view<value_t , index_t , col_major>(
1028
1033
eigVectorTempView.data_handle (), gramDim, gramDim);
1029
- bmat (handle, gramAView, A_blocks);
1030
- bmat (handle, gramBView, B_blocks);
1034
+ bmat (handle, gramAView, A_blocks, 2 );
1035
+ bmat (handle, gramBView, B_blocks, 2 );
1031
1036
bool eig_sucess = eigh (
1032
1037
handle, gramAView, std::make_optional (gramBView), eigVectorTempView, eigLambdaTempView);
1033
1038
ASSERT (eig_sucess, " lobpcg: eigh has failed in lobpcg iterations" );
@@ -1048,20 +1053,20 @@ void lobpcg(
1048
1053
auto app = raft::make_device_matrix<value_t , index_t , raft::col_major>(handle, n, size_x);
1049
1054
if (B_opt.has_value ()) {
1050
1055
auto bpp = raft::make_device_matrix<value_t , index_t , raft::col_major>(handle, n, size_x);
1051
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorX.view (),
1056
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorX.view (),
1052
1057
raft::matrix::slice_coordinates<index_t >(0 , 0 , size_x, size_x));
1053
1058
if (!restart) {
1054
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1059
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
1055
1060
raft::matrix::slice_coordinates<index_t >(size_x, 0 , size_x + currentBlockSize, size_x));
1056
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorP.view (),
1061
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorP.view (),
1057
1062
raft::matrix::slice_coordinates<index_t >(size_x + currentBlockSize, 0 , gramDim, size_x));
1058
1063
} else {
1059
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1064
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
1060
1065
raft::matrix::slice_coordinates<index_t >(size_x, 0 , gramDim, size_x));
1061
1066
}
1062
1067
1063
- raft::linalg::gemm (handle, activeRView , eigBlockVectorR.view (), pp.view ());
1064
- raft::linalg::gemm (handle, activeARView , eigBlockVectorR.view (), app.view ());
1068
+ raft::linalg::gemm (handle, activeR. view () , eigBlockVectorR.view (), pp.view ());
1069
+ raft::linalg::gemm (handle, activeAR. view () , eigBlockVectorR.view (), app.view ());
1065
1070
raft::linalg::gemm (handle, activeBRView, eigBlockVectorR.view (), bpp.view ());
1066
1071
if (!restart) {
1067
1072
raft::linalg::gemm (handle, activePView, eigBlockVectorP.view (), pp.view (), one, one);
@@ -1087,20 +1092,20 @@ void lobpcg(
1087
1092
raft::copy (AX.data_handle (), app.data_handle (), app.size (), stream);
1088
1093
raft::copy (BXView.data_handle (), bpp.data_handle (), bpp.size (), stream);
1089
1094
} else {
1090
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorX.view (),
1095
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorX.view (),
1091
1096
raft::matrix::slice_coordinates<index_t >(0 , 0 , size_x, size_x));
1092
1097
if (!restart) {
1093
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1098
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
1094
1099
raft::matrix::slice_coordinates<index_t >(size_x, 0 , size_x + currentBlockSize, size_x));
1095
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorP.view (),
1100
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorP.view (),
1096
1101
raft::matrix::slice_coordinates<index_t >(size_x + currentBlockSize, 0 , gramDim, size_x));
1097
1102
} else {
1098
- raft::matrix::slice (handle, make_const_mdpsan (eigVectorView), eigBlockVectorR.view (),
1103
+ raft::matrix::slice (handle, make_const_mdspan (eigVectorView), eigBlockVectorR.view (),
1099
1104
raft::matrix::slice_coordinates<index_t >(size_x, 0 , gramDim, size_x));
1100
1105
}
1101
1106
1102
- raft::linalg::gemm (handle, activeRView , eigBlockVectorR.view (), pp.view ());
1103
- raft::linalg::gemm (handle, activeARView , eigBlockVectorR.view (), app.view ());
1107
+ raft::linalg::gemm (handle, activeR. view () , eigBlockVectorR.view (), pp.view ());
1108
+ raft::linalg::gemm (handle, activeAR. view () , eigBlockVectorR.view (), app.view ());
1104
1109
if (!restart) {
1105
1110
raft::linalg::gemm (handle, activePView, eigBlockVectorP.view (), pp.view (), one, one);
1106
1111
raft::linalg::gemm (handle, activeAPView, eigBlockVectorP.view (), app.view (), one, one);
@@ -1121,12 +1126,31 @@ void lobpcg(
1121
1126
}
1122
1127
}
1123
1128
1124
- if (B_opt.has_value ()) { // Using blockVectorR instead of aux
1125
- raft::copy (R .data_handle (), BXView.data_handle (), BXView.size (), stream);
1129
+ if (B_opt.has_value ()) {
1130
+ raft::copy (aux .data_handle (), BXView.data_handle (), BXView.size (), stream);
1126
1131
} else {
1127
- raft::copy (R.data_handle (), X.data_handle (), X.size (), stream);
1132
+ raft::copy (aux.data_handle (), X.data_handle (), X.size (), stream);
1133
+ }
1134
+ raft::linalg::binary_mult_skip_zero (handle, aux.view (), make_const_mdspan (eigLambda.view ()), raft::linalg::Apply::ALONG_ROWS);
1135
+
1136
+ raft::linalg::subtract (
1137
+ handle, raft::make_const_mdspan (AX.view ()), raft::make_const_mdspan (aux.view ()), R.view ());
1138
+
1139
+ raft::linalg::reduce (
1140
+ residual_norms.data_handle (),
1141
+ R.data_handle (),
1142
+ size_x,
1143
+ n,
1144
+ value_t (0 ),
1145
+ false ,
1146
+ true ,
1147
+ stream,
1148
+ false ,
1149
+ raft::sq_op ());
1150
+ // TODO check reduce sqrt postop raft::linalg::sqrt(handle, raft::make_const_mdspan(aux_sum.view()), residual_norms.view());
1151
+
1152
+ if (verbosityLevel > 0 ) {
1153
+ // / TODO add verb
1128
1154
}
1129
- raft::linalg::binary_mult_skip_zero (handle, R.view (), make_const_mdspan (eigLambda.view ()), linalg::Apply::ALONG_ROWS);
1130
- raft::linalg::gemm (handle, AX.view (),)
1131
1155
}
1132
1156
}; // namespace raft::sparse::solver::detail
0 commit comments