Skip to content

Commit 1618001

Browse files
committed
[EM] Prevent blocking memcpy during prediction.
1 parent 843d5ee commit 1618001

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

include/xgboost/host_device_vector.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ class HostDeviceVector {
101101

102102
[[nodiscard]] bool Empty() const { return Size() == 0; }
103103
[[nodiscard]] std::size_t Size() const;
104+
[[nodiscard]] std::size_t SizeBytes() const { return this->Size() * sizeof(T); }
104105
[[nodiscard]] DeviceOrd Device() const;
105106
common::Span<T> DeviceSpan();
106107
common::Span<const T> ConstDeviceSpan() const;

src/predictor/gpu_predictor.cu

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,20 @@ class DeviceModel {
364364
int num_group;
365365
CatContainer const* cat_enc{nullptr};
366366

367+
[[nodiscard]] std::size_t MemCostBytes() const {
368+
std::size_t n_bytes = 0;
369+
n_bytes += stats.ConstDeviceSpan().size_bytes();
370+
n_bytes += tree_segments.ConstDeviceSpan().size_bytes();
371+
n_bytes += nodes.ConstDeviceSpan().size_bytes();
372+
n_bytes += tree_group.ConstDeviceSpan().size_bytes();
373+
n_bytes += split_types.ConstDeviceSpan().size_bytes();
374+
n_bytes += categories_tree_segments.ConstDeviceSpan().size_bytes();
375+
n_bytes += categories_node_segments.ConstDeviceSpan().size_bytes();
376+
n_bytes += categories.ConstDeviceSpan().size_bytes();
377+
n_bytes += sizeof(tree_beg_) + sizeof(tree_end_) + sizeof(num_group) + sizeof(cat_enc);
378+
return n_bytes;
379+
}
380+
367381
void Init(const gbm::GBTreeModel& model, bst_tree_t tree_begin, bst_tree_t tree_end,
368382
DeviceOrd device) {
369383
dh::safe_cuda(cudaSetDevice(device.ordinal));
@@ -438,6 +452,9 @@ class DeviceModel {
438452

439453
this->cat_enc = model.Cats();
440454
CHECK(this->cat_enc);
455+
456+
auto n_bytes = this->MemCostBytes(); // Pull data to device, and get the size of the model.
457+
LOG(DEBUG) << "Model size:" << common::HumanMemUnit(n_bytes);
441458
}
442459
};
443460

0 commit comments

Comments
 (0)