Skip to content

Commit e22ec68

Browse files
author
Dmitry Razdoburdin
committed
minor dispatcher refactor
1 parent 13bec86 commit e22ec68

File tree

6 files changed

+167
-168
lines changed

6 files changed

+167
-168
lines changed

plugin/sycl/common/hist_util.cc

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,29 @@ template ::sycl::event SubtractionHist(::sycl::queue* qu,
9292
const GHistRow<double, MemoryType::on_device>& src2,
9393
size_t size, ::sycl::event event_priv);
9494

95+
template <typename GradientPairT>
96+
::sycl::event ReduceHist(::sycl::queue* qu, GradientPairT* hist_data,
97+
GradientPairT* hist_buffer_data,
98+
size_t nblocks, size_t nbins,
99+
const ::sycl::event& event_main) {
100+
auto event_save = qu->submit([&](::sycl::handler& cgh) {
101+
cgh.depends_on(event_main);
102+
cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
103+
size_t idx_bin = pid.get_id(0);
104+
105+
GradientPairT gpair = {0, 0};
106+
107+
for (size_t j = 0; j < nblocks; ++j) {
108+
gpair += hist_buffer_data[j * nbins + idx_bin];
109+
}
110+
111+
hist_data[idx_bin] = gpair;
112+
});
113+
});
114+
115+
return event_save;
116+
}
117+
95118
// Kernel that uses a buffer
96119
template<typename FPType, typename BinIdxType, bool isDense>
97120
::sycl::event BuildHistKernel(::sycl::queue* qu,
@@ -100,7 +123,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
100123
const GHistIndexMatrix& gmat,
101124
GHistRow<FPType, MemoryType::on_device>* hist,
102125
GHistRow<FPType, MemoryType::on_device>* hist_buffer,
103-
const tree::HistBuildParameters& params,
126+
const tree::HistDispatcher<FPType>& dispatcher,
104127
::sycl::event event_priv) {
105128
using GradientPairT = xgboost::detail::GradientPairInternal<FPType>;
106129
const size_t size = row_indices.Size();
@@ -111,9 +134,9 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
111134
const uint32_t* offsets = gmat.cut.cut_ptrs_.ConstDevicePointer();
112135
const size_t nbins = gmat.nbins;
113136

114-
const size_t work_group_size = params.work_group_size;
115-
const size_t block_size = params.block.size;
116-
const size_t nblocks = params.block.nblocks;
137+
const size_t work_group_size = dispatcher.work_group_size;
138+
const size_t block_size = dispatcher.block.size;
139+
const size_t nblocks = dispatcher.block.nblocks;
117140

118141
GradientPairT* hist_buffer_data = hist_buffer->Data();
119142
auto event_fill = qu->fill(hist_buffer_data, GradientPairT(0, 0),
@@ -152,20 +175,9 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
152175
});
153176

154177
GradientPairT* hist_data = hist->Data();
155-
auto event_save = qu->submit([&](::sycl::handler& cgh) {
156-
cgh.depends_on(event_main);
157-
cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
158-
size_t idx_bin = pid.get_id(0);
159-
160-
GradientPairT gpair = {0, 0};
161-
162-
for (size_t j = 0; j < nblocks; ++j) {
163-
gpair += hist_buffer_data[j * nbins + idx_bin];
164-
}
178+
auto event_save = ReduceHist(qu, hist_data, hist_buffer_data, nblocks,
179+
nbins, event_main);
165180

166-
hist_data[idx_bin] = gpair;
167-
});
168-
});
169181
return event_save;
170182
}
171183

@@ -177,9 +189,9 @@ ::sycl::event BuildHistKernelLocal(::sycl::queue* qu,
177189
const GHistIndexMatrix& gmat,
178190
GHistRow<FPType, MemoryType::on_device>* hist,
179191
GHistRow<FPType, MemoryType::on_device>* hist_buffer,
180-
const tree::HistBuildParameters& params,
192+
const tree::HistDispatcher<FPType>& dispatcher,
181193
::sycl::event event_priv) {
182-
constexpr int kMaxNumBins = tree::HistDispatcher::KMaxNumBins;
194+
constexpr int kMaxNumBins = tree::HistDispatcher<FPType>::KMaxNumBins;
183195
using GradientPairT = xgboost::detail::GradientPairInternal<FPType>;
184196
const size_t size = row_indices.Size();
185197
const size_t* rid = row_indices.begin;
@@ -189,9 +201,9 @@ ::sycl::event BuildHistKernelLocal(::sycl::queue* qu,
189201
const uint32_t* offsets = gmat.cut.cut_ptrs_.ConstDevicePointer();
190202
const size_t nbins = gmat.nbins;
191203

192-
const size_t work_group_size = params.work_group_size;
193-
const size_t block_size = params.block.size;
194-
const size_t nblocks = params.block.nblocks;
204+
const size_t work_group_size = dispatcher.work_group_size;
205+
const size_t block_size = dispatcher.block.size;
206+
const size_t nblocks = dispatcher.block.nblocks;
195207

196208
GradientPairT* hist_buffer_data = hist_buffer->Data();
197209

@@ -239,20 +251,8 @@ ::sycl::event BuildHistKernelLocal(::sycl::queue* qu,
239251
});
240252

241253
GradientPairT* hist_data = hist->Data();
242-
auto event_save = qu->submit([&](::sycl::handler& cgh) {
243-
cgh.depends_on(event_main);
244-
cgh.parallel_for<>(::sycl::range<1>(nbins), [=](::sycl::item<1> pid) {
245-
size_t idx_bin = pid.get_id(0);
246-
247-
GradientPairT gpair = {0, 0};
248-
249-
for (size_t j = 0; j < nblocks; ++j) {
250-
gpair += hist_buffer_data[j * nbins + idx_bin];
251-
}
252-
253-
hist_data[idx_bin] = gpair;
254-
});
255-
});
254+
auto event_save = ReduceHist(qu, hist_data, hist_buffer_data, nblocks,
255+
nbins, event_main);
256256
return event_save;
257257
}
258258

@@ -263,7 +263,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
263263
const RowSetCollection::Elem& row_indices,
264264
const GHistIndexMatrix& gmat,
265265
GHistRow<FPType, MemoryType::on_device>* hist,
266-
const tree::HistBuildParameters& params,
266+
const tree::HistDispatcher<FPType>& dispatcher,
267267
::sycl::event event_priv) {
268268
const size_t size = row_indices.Size();
269269
const size_t* rid = row_indices.begin;
@@ -275,7 +275,7 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
275275
FPType* hist_data = reinterpret_cast<FPType*>(hist->Data());
276276
const size_t nbins = gmat.nbins;
277277

278-
size_t work_group_size = params.work_group_size;
278+
size_t work_group_size = dispatcher.work_group_size;
279279
const size_t n_work_groups = n_columns / work_group_size + (n_columns % work_group_size > 0);
280280

281281
auto event_fill = qu->fill(hist_data, FPType(0), nbins * 2, event_priv);
@@ -321,47 +321,47 @@ ::sycl::event BuildHistDispatchKernel(
321321
GHistRow<FPType, MemoryType::on_device>* hist,
322322
bool isDense,
323323
GHistRow<FPType, MemoryType::on_device>* hist_buffer,
324-
const tree::HistDispatcher& dispatcher,
324+
const tree::DeviceProperties& device_prop,
325325
::sycl::event events_priv,
326326
bool force_atomic_use) {
327327
const size_t size = row_indices.Size();
328328
const size_t n_columns = isDense ? gmat.nfeatures : gmat.row_stride;
329329
const size_t nbins = gmat.nbins;
330330
const size_t max_num_bins = gmat.max_num_bins;
331331
const size_t min_num_bins = gmat.min_num_bins;
332-
using GradientPairT = xgboost::detail::GradientPairInternal<FPType>;
333332

334-
size_t max_n_blocks = hist_buffer->Size() / (nbins * 2);
335-
auto build_params = dispatcher.GetHistBuildParameters<GradientPairT>
336-
(isDense, size, max_n_blocks, nbins, n_columns, max_num_bins, min_num_bins);
333+
size_t max_n_blocks = hist_buffer->Size() / nbins;
334+
auto dispatcher = tree::HistDispatcher<FPType>
335+
(device_prop, isDense, size, max_n_blocks, nbins,
336+
n_columns, max_num_bins, min_num_bins);
337337

338338
// force_atomic_use flag is used only for testing
339-
bool use_atomic = build_params.use_atomics || force_atomic_use;
339+
bool use_atomic = dispatcher.use_atomics || force_atomic_use;
340340
if (!use_atomic) {
341341
if (isDense) {
342-
if (build_params.use_local_hist) {
342+
if (dispatcher.use_local_hist) {
343343
return BuildHistKernelLocal<FPType, BinIdxType>(qu, gpair, row_indices,
344344
gmat, hist, hist_buffer,
345-
build_params, events_priv);
345+
dispatcher, events_priv);
346346
} else {
347347
return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair, row_indices,
348348
gmat, hist, hist_buffer,
349-
build_params, events_priv);
349+
dispatcher, events_priv);
350350
}
351351
} else {
352352
return BuildHistKernel<FPType, uint32_t, false>(qu, gpair, row_indices,
353353
gmat, hist, hist_buffer,
354-
build_params, events_priv);
354+
dispatcher, events_priv);
355355
}
356356
} else {
357357
if (isDense) {
358358
return BuildHistKernel<FPType, BinIdxType, true>(qu, gpair, row_indices,
359359
gmat, hist,
360-
build_params, events_priv);
360+
dispatcher, events_priv);
361361
} else {
362362
return BuildHistKernel<FPType, uint32_t, false>(qu, gpair, row_indices,
363363
gmat, hist,
364-
build_params, events_priv);
364+
dispatcher, events_priv);
365365
}
366366
}
367367
}
@@ -373,27 +373,27 @@ ::sycl::event BuildHistKernel(::sycl::queue* qu,
373373
const GHistIndexMatrix& gmat, const bool isDense,
374374
GHistRow<FPType, MemoryType::on_device>* hist,
375375
GHistRow<FPType, MemoryType::on_device>* hist_buffer,
376-
const tree::HistDispatcher& dispatcher,
376+
const tree::DeviceProperties& device_prop,
377377
::sycl::event event_priv,
378378
bool force_atomic_use) {
379379
const bool is_dense = isDense;
380380
switch (gmat.index.GetBinTypeSize()) {
381381
case BinTypeSize::kUint8BinsTypeSize:
382382
return BuildHistDispatchKernel<FPType, uint8_t>(qu, gpair, row_indices,
383383
gmat, hist, is_dense, hist_buffer,
384-
dispatcher,
384+
device_prop,
385385
event_priv, force_atomic_use);
386386
break;
387387
case BinTypeSize::kUint16BinsTypeSize:
388388
return BuildHistDispatchKernel<FPType, uint16_t>(qu, gpair, row_indices,
389389
gmat, hist, is_dense, hist_buffer,
390-
dispatcher,
390+
device_prop,
391391
event_priv, force_atomic_use);
392392
break;
393393
case BinTypeSize::kUint32BinsTypeSize:
394394
return BuildHistDispatchKernel<FPType, uint32_t>(qu, gpair, row_indices,
395395
gmat, hist, is_dense, hist_buffer,
396-
dispatcher,
396+
device_prop,
397397
event_priv, force_atomic_use);
398398
break;
399399
default:
@@ -409,12 +409,12 @@ ::sycl::event GHistBuilder<GradientSumT>::BuildHist(
409409
GHistRowT<MemoryType::on_device>* hist,
410410
bool isDense,
411411
GHistRowT<MemoryType::on_device>* hist_buffer,
412-
const tree::HistDispatcher& dispatcher,
412+
const tree::DeviceProperties& device_prop,
413413
::sycl::event event_priv,
414414
bool force_atomic_use) {
415415
return BuildHistKernel<GradientSumT>(qu_, gpair, row_indices, gmat,
416416
isDense, hist, hist_buffer,
417-
dispatcher, event_priv,
417+
device_prop, event_priv,
418418
force_atomic_use);
419419
}
420420

@@ -426,7 +426,7 @@ ::sycl::event GHistBuilder<float>::BuildHist(
426426
GHistRow<float, MemoryType::on_device>* hist,
427427
bool isDense,
428428
GHistRow<float, MemoryType::on_device>* hist_buffer,
429-
const tree::HistDispatcher& dispatcher,
429+
const tree::DeviceProperties& device_prop,
430430
::sycl::event event_priv,
431431
bool force_atomic_use);
432432
template
@@ -437,7 +437,7 @@ ::sycl::event GHistBuilder<double>::BuildHist(
437437
GHistRow<double, MemoryType::on_device>* hist,
438438
bool isDense,
439439
GHistRow<double, MemoryType::on_device>* hist_buffer,
440-
const tree::HistDispatcher& dispatcher,
440+
const tree::DeviceProperties& device_prop,
441441
::sycl::event event_priv,
442442
bool force_atomic_use);
443443

plugin/sycl/common/hist_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class ParallelGHistBuilder {
124124
}
125125

126126
void Reset(size_t nblocks) {
127-
hist_device_buffer_.Resize(qu_, nblocks * nbins_ * 2);
127+
hist_device_buffer_.Resize(qu_, nblocks * nbins_);
128128
}
129129

130130
GHistRowT& GetDeviceBuffer() {
@@ -162,7 +162,7 @@ class GHistBuilder {
162162
GHistRowT<MemoryType::on_device>* HistCollection,
163163
bool isDense,
164164
GHistRowT<MemoryType::on_device>* hist_buffer,
165-
const tree::HistDispatcher& dispatcher,
165+
const tree::DeviceProperties& device_prop,
166166
::sycl::event event,
167167
bool force_atomic_use = false);
168168

0 commit comments

Comments
 (0)