Skip to content

Commit fb6bfe6

Browse files
authored
Introduction of the raft::device_resources_snmg type (#2487)
Introduces the `raft::device_resources_snmg` type to hold all resources required for the NCCL clique.

~Answers #2459
Removed call to `raft::comms::build_comms_nccl_only` (#2465)

Authors:
- Victor Lafargue (https://github.com/viclafargue)
- Corey J. Nolet (https://github.com/cjnolet)

Approvers:
- Corey J. Nolet (https://github.com/cjnolet)

URL: #2487
1 parent 8ea0e7e commit fb6bfe6

File tree

5 files changed

+236
-223
lines changed

5 files changed

+236
-223
lines changed

cpp/include/raft/comms/nccl_clique.hpp

-156
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
/*
2+
* Copyright (c) 2024, NVIDIA CORPORATION.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include <raft/core/device_resources.hpp>

#include <nccl.h>
#include <omp.h>

#include <memory>
#include <numeric>
#include <string>
#include <vector>
26+
27+
/**
 * @brief Error checking macro for NCCL runtime API functions.
 *
 * Evaluates `call` exactly once and stores its result as a `ncclResult_t`. If the result is not
 * `ncclSuccess`, a message is assembled with `SET_ERROR_MSG` from the stringified call, the
 * numeric status and the text returned by `ncclGetErrorString(status)`, and a
 * `raft::logic_error` carrying that message is thrown.
 *
 * `SET_ERROR_MSG` and `raft::logic_error` are not defined in this file; presumably they come from
 * the included raft headers -- confirm.
 *
 * NOTE(review): the expansion itself ends with a semicolon after `while (0)`. Writing
 * `RAFT_NCCL_TRY(...);` therefore produces an additional empty statement, and the macro cannot be
 * the unbraced body of an `if` that is followed by `else`.
 */
33+
#define RAFT_NCCL_TRY(call) \
34+
do { \
35+
ncclResult_t const status = (call); \
36+
if (ncclSuccess != status) { \
37+
std::string msg{}; \
38+
SET_ERROR_MSG(msg, \
39+
"NCCL error encountered at: ", \
40+
"call='%s', Reason=%d:%s", \
41+
#call, \
42+
status, \
43+
ncclGetErrorString(status)); \
44+
throw raft::logic_error(msg); \
45+
} \
46+
} while (0);
47+
48+
namespace raft {
49+
50+
/**
51+
* @brief SNMG (single-node multi-GPU) resource container object that stores a NCCL clique and all
52+
* necessary resources used for calling device functions, cuda kernels, libraries and/or NCCL
53+
* communications on each GPU. Note the `device_resources_snmg` object can also be used as a classic
54+
* `device_resources` object. The associated resources will be the ones of the GPU used during
55+
* object instantiation and a GPU switch operation will be ordered during the retrieval of said
56+
* resources.
57+
*
58+
* The `device_resources_snmg` class is intended to be used in a single process to manage several
59+
* GPUs. Please note that NCCL communications are the responsibility of the user. Blocking NCCL
60+
* calls will sometimes require the use of several threads to avoid hangs.
61+
*/
62+
class device_resources_snmg : public device_resources {
63+
public:
64+
/**
65+
* @brief Construct a SNMG resources instance with all available GPUs
66+
*/
67+
device_resources_snmg() : device_resources(), root_rank_(0)
68+
{
69+
cudaGetDevice(&main_gpu_id_);
70+
71+
int num_ranks;
72+
RAFT_CUDA_TRY(cudaGetDeviceCount(&num_ranks));
73+
device_ids_.resize(num_ranks);
74+
std::iota(device_ids_.begin(), device_ids_.end(), 0);
75+
nccl_comms_.resize(num_ranks);
76+
initialize();
77+
}
78+
79+
/**
80+
* @brief Construct a SNMG resources instance with a subset of available GPUs
81+
*
82+
* @param[in] device_ids List of device IDs to be used by the NCCL clique
83+
*/
84+
device_resources_snmg(const std::vector<int>& device_ids)
85+
: device_resources(), root_rank_(0), device_ids_(device_ids), nccl_comms_(device_ids.size())
86+
{
87+
cudaGetDevice(&main_gpu_id_);
88+
89+
initialize();
90+
}
91+
92+
/**
93+
* @brief SNMG resources instance copy constructor
94+
*
95+
* @param[in] clique A SNMG resources instance
96+
*/
97+
device_resources_snmg(const device_resources_snmg& clique)
98+
: device_resources(clique),
99+
root_rank_(clique.root_rank_),
100+
main_gpu_id_(clique.main_gpu_id_),
101+
device_ids_(clique.device_ids_),
102+
nccl_comms_(clique.nccl_comms_),
103+
device_resources_(clique.device_resources_)
104+
{
105+
}
106+
107+
device_resources_snmg(device_resources_snmg&&) = delete;
108+
device_resources_snmg& operator=(device_resources_snmg&&) = delete;
109+
110+
/**
111+
* @brief Set root rank of NCCL clique
112+
*/
113+
inline int set_root_rank(int rank) { this->root_rank_ = rank; }
114+
115+
/**
116+
* @brief Get root rank of NCCL clique
117+
*/
118+
inline int get_root_rank() const { return this->root_rank_; }
119+
120+
/**
121+
* @brief Get number of ranks in NCCL clique
122+
*/
123+
inline int get_num_ranks() const { return this->device_ids_.size(); }
124+
125+
/**
126+
* @brief Get device ID of rank in NCCL clique
127+
*/
128+
inline int get_device_id(int rank) const { return this->device_ids_[rank]; }
129+
130+
/**
131+
* @brief Get NCCL comm object of rank in NCCL clique
132+
*/
133+
inline ncclComm_t get_nccl_comm(int rank) const { return this->nccl_comms_[rank]; }
134+
135+
/**
136+
* @brief Get raft::device_resources object of rank in NCCL clique
137+
*/
138+
inline const raft::device_resources& get_device_resources(int rank) const
139+
{
140+
return this->device_resources_[rank];
141+
}
142+
143+
/**
144+
* @brief Set current device ID to root rank and return its raft::device_resources object
145+
*/
146+
inline const raft::device_resources& set_current_device_to_root_rank() const
147+
{
148+
return set_current_device_to_rank(get_root_rank());
149+
}
150+
151+
/**
152+
* @brief Set current device ID to rank and return its raft::device_resources object
153+
*/
154+
inline const raft::device_resources& set_current_device_to_rank(int rank) const
155+
{
156+
RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank)));
157+
return get_device_resources(rank);
158+
}
159+
160+
/**
161+
* @brief Set a memory pool on all GPUs of the clique
162+
*/
163+
void set_memory_pool(int percent_of_free_memory) const
164+
{
165+
for (int rank = 0; rank < get_num_ranks(); rank++) {
166+
RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank)));
167+
size_t limit =
168+
rmm::percent_of_free_device_memory(percent_of_free_memory); // check limit for each device
169+
raft::resource::set_workspace_to_pool_resource(get_device_resources(rank), limit);
170+
}
171+
cudaSetDevice(this->main_gpu_id_);
172+
}
173+
174+
bool has_resource_factory(resource::resource_type resource_type) const override
175+
{
176+
cudaSetDevice(this->main_gpu_id_);
177+
return raft::resources::has_resource_factory(resource_type);
178+
}
179+
180+
/** Destroys all held-up resources */
181+
~device_resources_snmg()
182+
{
183+
#pragma omp parallel for // necessary to avoid hangs
184+
for (int rank = 0; rank < get_num_ranks(); rank++) {
185+
RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank)));
186+
RAFT_NCCL_TRY(ncclCommDestroy(get_nccl_comm(rank)));
187+
}
188+
cudaSetDevice(this->main_gpu_id_);
189+
}
190+
191+
private:
192+
/**
193+
* @brief Initializes the NCCL clique and raft::device_resources objects
194+
*/
195+
void initialize()
196+
{
197+
RAFT_NCCL_TRY(ncclCommInitAll(nccl_comms_.data(), get_num_ranks(), device_ids_.data()));
198+
199+
for (int rank = 0; rank < get_num_ranks(); rank++) {
200+
RAFT_CUDA_TRY(cudaSetDevice(get_device_id(rank)));
201+
device_resources_.emplace_back();
202+
203+
// ideally add the ncclComm_t to the device_resources object with
204+
// raft::comms::build_comms_nccl_only
205+
}
206+
cudaSetDevice(this->main_gpu_id_);
207+
}
208+
209+
int root_rank_;
210+
int main_gpu_id_;
211+
std::vector<int> device_ids_;
212+
std::vector<ncclComm_t> nccl_comms_;
213+
std::vector<raft::device_resources> device_resources_;
214+
215+
}; // class device_resources_snmg
216+
217+
} // namespace raft

0 commit comments

Comments
 (0)