@@ -724,18 +724,19 @@ class ColumnSplitHelper {
724724 SparsePageView data{ctx_, batch, num_features};
725725 auto const grid = static_cast <uint32_t >(common::DivRoundUp (num_rows, kBlockThreads ));
726726 auto d_tree_groups = d_model.tree_groups ;
727- dh::LaunchKernel {grid, kBlockThreads , shared_memory_bytes, ctx_->CUDACtx ()->Stream ()}(
727+ dh::LaunchKernel{grid, kBlockThreads , shared_memory_bytes, // NOLINT(whitespace/braces)
728+ ctx_->CUDACtx ()->Stream ()}(
728729 MaskBitVectorKernel, data, d_model.Trees (), decision_bits, missing_bits,
729730 d_model.tree_begin , d_model.tree_end , num_features, num_nodes, use_shared,
730731 std::numeric_limits<float >::quiet_NaN ());
731732
732733 AllReduceBitVectors (&decision_storage, &missing_storage);
733734
734- dh::LaunchKernel {grid, kBlockThreads , 0 , ctx_->CUDACtx ()->Stream ()}(
735+ dh::LaunchKernel{grid, kBlockThreads , 0 , // NOLINT(whitespace/braces)
736+ ctx_->CUDACtx ()->Stream ()}(
735737 PredictByBitVectorKernel<predict_leaf>, d_model.Trees (),
736- out_preds->DeviceSpan ().subspan (batch_offset), d_tree_groups,
737- decision_bits, missing_bits, d_model.tree_begin , d_model.tree_end , num_rows, num_nodes,
738- num_group);
738+ out_preds->DeviceSpan ().subspan (batch_offset), d_tree_groups, decision_bits, missing_bits,
739+ d_model.tree_begin , d_model.tree_end , num_rows, num_nodes, num_group);
739740
740741 batch_offset += batch.Size () * num_group;
741742 }
@@ -858,8 +859,7 @@ class LaunchConfig {
858859 }
859860
860861 public:
861- LaunchConfig (Context const * ctx, bst_feature_t n_features)
862- : ctx_{ctx}, n_features_{n_features} {}
862+ LaunchConfig (Context const * ctx, bst_feature_t n_features) : ctx_{ctx}, n_features_{n_features} {}
863863
864864 template <typename Fn>
865865 void ForEachBatch (DMatrix* p_fmat, Fn&& fn) {
@@ -974,8 +974,8 @@ class GPUPredictor : public xgboost::Predictor {
974974 out_preds->SetDevice (ctx_->Device ());
975975 auto const & info = p_fmat->Info ();
976976
977- DeviceModel d_model{this ->ctx_ ->Device (), model, false , tree_begin, tree_end, & this -> model_mu_ ,
978- CopyViews{this ->ctx_ }};
977+ DeviceModel d_model{this ->ctx_ ->Device (), model, false ,
978+ tree_begin, tree_end, CopyViews{this ->ctx_ }};
979979
980980 if (info.IsColumnSplit ()) {
981981 column_split_helper_.PredictBatch (p_fmat, out_preds, model, d_model);
@@ -1036,8 +1036,7 @@ class GPUPredictor : public xgboost::Predictor {
10361036 auto n_samples = m->NumRows ();
10371037 auto n_features = model.learner_model_param ->num_feature ;
10381038
1039- DeviceModel d_model{ctx_->Device (), model, false , tree_begin, tree_end, &this ->model_mu_ ,
1040- CopyViews{this ->ctx_ }};
1039+ DeviceModel d_model{ctx_->Device (), model, false , tree_begin, tree_end, CopyViews{this ->ctx_ }};
10411040
10421041 if constexpr (std::is_same_v<Adapter, data::CudfAdapter>) {
10431042 if (m->HasCategorical ()) {
@@ -1055,18 +1054,16 @@ class GPUPredictor : public xgboost::Predictor {
10551054 }
10561055 }
10571056
1058- LaunchPredict (this ->ctx_ , false , enc::DeviceColumnsView{}, model,
1059- [&](auto && cfg, auto && acc) {
1060- using EncAccessor = std::remove_reference_t <decltype (acc)>;
1061- CHECK ((std::is_same_v<EncAccessor, NoOpAccessor>));
1062- using LoaderImpl = DeviceAdapterLoader<BatchT, EncAccessor>;
1063- using Loader =
1064- typename common::GetValueT<decltype (cfg)>::template LoaderType<LoaderImpl,
1065- 128 >;
1066- cfg.template AllocShmem <Loader>();
1067- cfg.template LaunchPredictKernel <Loader>(
1068- m->Value (), missing, n_features, d_model, acc, 0 , &out_preds->predictions );
1069- });
1057+ LaunchPredict (this ->ctx_ , false , enc::DeviceColumnsView{}, model, [&](auto && cfg, auto && acc) {
1058+ using EncAccessor = std::remove_reference_t <decltype (acc)>;
1059+ CHECK ((std::is_same_v<EncAccessor, NoOpAccessor>));
1060+ using LoaderImpl = DeviceAdapterLoader<BatchT, EncAccessor>;
1061+ using Loader =
1062+ typename common::GetValueT<decltype (cfg)>::template LoaderType<LoaderImpl, 128 >;
1063+ cfg.template AllocShmem <Loader>();
1064+ cfg.template LaunchPredictKernel <Loader>(m->Value (), missing, n_features, d_model, acc, 0 ,
1065+ &out_preds->predictions );
1066+ });
10701067 }
10711068
10721069 [[nodiscard]] bool InplacePredict (std::shared_ptr<DMatrix> p_m, gbm::GBTreeModel const & model,
@@ -1116,8 +1113,7 @@ class GPUPredictor : public xgboost::Predictor {
11161113 auto phis = out_contribs->DeviceSpan ();
11171114
11181115 dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> device_paths;
1119- DeviceModel d_model{this ->ctx_ ->Device (), model, true , 0 , tree_end, &this ->model_mu_ ,
1120- CopyViews{this ->ctx_ }};
1116+ DeviceModel d_model{this ->ctx_ ->Device (), model, true , 0 , tree_end, CopyViews{this ->ctx_ }};
11211117
11221118 auto new_enc =
11231119 p_fmat->Cats ()->NeedRecode () ? p_fmat->Cats ()->DeviceView (ctx_) : enc::DeviceColumnsView{};
@@ -1177,8 +1173,7 @@ class GPUPredictor : public xgboost::Predictor {
11771173 auto phis = out_contribs->DeviceSpan ();
11781174
11791175 dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> device_paths;
1180- DeviceModel d_model{this ->ctx_ ->Device (), model, true , 0 , tree_end, &this ->model_mu_ ,
1181- CopyViews{this ->ctx_ }};
1176+ DeviceModel d_model{this ->ctx_ ->Device (), model, true , 0 , tree_end, CopyViews{this ->ctx_ }};
11821177
11831178 dh::device_vector<uint32_t > categories;
11841179 ExtractPaths (ctx_, &device_paths, model, d_model, &categories);
@@ -1223,8 +1218,7 @@ class GPUPredictor : public xgboost::Predictor {
12231218 predictions->SetDevice (ctx_->Device ());
12241219 predictions->Resize (n_samples * tree_end);
12251220
1226- DeviceModel d_model{ctx_->Device (), model, false , 0 , tree_end, &this ->model_mu_ ,
1227- CopyViews{this ->ctx_ }};
1221+ DeviceModel d_model{ctx_->Device (), model, false , 0 , tree_end, CopyViews{this ->ctx_ }};
12281222
12291223 if (info.IsColumnSplit ()) {
12301224 column_split_helper_.PredictLeaf (p_fmat, predictions, model, d_model);
@@ -1254,8 +1248,6 @@ class GPUPredictor : public xgboost::Predictor {
12541248 }
12551249
12561250 private:
1257- // Prevent multiple threads from pulling the model to device together.
1258- mutable std::mutex model_mu_;
12591251 ColumnSplitHelper column_split_helper_;
12601252};
12611253
0 commit comments