Skip to content

Commit aa3ed75

Browse files
UCT/CUDA: fix review comments 3
1 parent ee2c733 commit aa3ed75

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

src/uct/cuda/cuda_ipc/cuda_ipc.cuh

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,30 +93,32 @@ UCS_F_DEVICE void uct_cuda_ipc_level_sync()
9393
return;
9494
}
9595

96-
template<typename VecT>
96+
template<typename vec_t>
9797
UCS_F_DEVICE void uct_cuda_ipc_try_copy_aligned(const char* &src, char* &dst,
9898
size_t &len,
9999
unsigned warp_id,
100100
unsigned num_warps,
101101
unsigned lane_id,
102102
unsigned num_lanes)
103103
{
104-
if (!(UCT_CUDA_IPC_IS_ALIGNED_POW2((intptr_t)src, sizeof(VecT)) &&
105-
UCT_CUDA_IPC_IS_ALIGNED_POW2((intptr_t)dst, sizeof(VecT)))) {
104+
constexpr size_t vec_size = sizeof(vec_t);
105+
106+
if (!(UCT_CUDA_IPC_IS_ALIGNED_POW2((intptr_t)src, vec_size) &&
107+
UCT_CUDA_IPC_IS_ALIGNED_POW2((intptr_t)dst, vec_size))) {
106108
return;
107109
}
108110

109-
auto src_vec = reinterpret_cast<const VecT*>(src);
110-
auto dst_vec = reinterpret_cast<VecT*>(dst);
111-
constexpr unsigned lanes_unroll = UCS_DEVICE_NUM_THREADS_IN_WARP *
111+
auto src_vec = reinterpret_cast<const vec_t*>(src);
112+
auto dst_vec = reinterpret_cast<vec_t*>(dst);
113+
constexpr size_t lanes_unroll = UCS_DEVICE_NUM_THREADS_IN_WARP *
112114
UCT_CUDA_IPC_COPY_LOOP_UNROLL;
113-
size_t num_lines = (len / (lanes_unroll * sizeof(VecT))) *
115+
size_t num_lines = (len / (lanes_unroll * vec_size)) *
114116
lanes_unroll;
115-
VecT tmp[UCT_CUDA_IPC_COPY_LOOP_UNROLL];
116117

117118
for (size_t line = warp_id * lanes_unroll + lane_id % UCS_DEVICE_NUM_THREADS_IN_WARP;
118119
line < num_lines;
119120
line += num_warps * lanes_unroll) {
121+
vec_t tmp[UCT_CUDA_IPC_COPY_LOOP_UNROLL];
120122
#pragma unroll
121123
for (int i = 0; i < UCT_CUDA_IPC_COPY_LOOP_UNROLL; i++) {
122124
tmp[i] = uct_cuda_ipc_ld_global_cg(
@@ -132,15 +134,15 @@ UCS_F_DEVICE void uct_cuda_ipc_try_copy_aligned(const char* &src, char* &dst,
132134

133135
src_vec += num_lines;
134136
dst_vec += num_lines;
135-
len = len - num_lines * sizeof(VecT);
137+
len = len - num_lines * vec_size;
136138

137-
num_lines = len / sizeof(VecT);
139+
num_lines = len / vec_size;
138140
for (size_t line = lane_id; line < num_lines; line += num_lanes) {
139-
VecT v = uct_cuda_ipc_ld_global_cg(src_vec + line);
141+
vec_t v = uct_cuda_ipc_ld_global_cg(src_vec + line);
140142
uct_cuda_ipc_st_global_cg(dst_vec + line, v);
141143
}
142144

143-
len -= num_lines * sizeof(VecT);
145+
len -= num_lines * vec_size;
144146
src = reinterpret_cast<const char*>(src_vec + num_lines);
145147
dst = reinterpret_cast<char*>(dst_vec + num_lines);
146148
}
@@ -165,11 +167,12 @@ UCS_F_DEVICE void uct_cuda_ipc_copy_level(void *dst, const void *src, size_t len
165167
{
166168
auto s1 = reinterpret_cast<const char*>(src);
167169
auto d1 = reinterpret_cast<char *>(dst);
168-
unsigned int lane_id, num_lanes, warp_id, num_warps;
170+
unsigned int lane_id, num_lanes;
169171

170172
uct_cuda_ipc_get_lane<level>(lane_id, num_lanes);
171-
warp_id = lane_id / UCS_DEVICE_NUM_THREADS_IN_WARP;
172-
num_warps = num_lanes / UCS_DEVICE_NUM_THREADS_IN_WARP;
173+
174+
const unsigned warp_id = lane_id / UCS_DEVICE_NUM_THREADS_IN_WARP;
175+
const unsigned num_warps = num_lanes / UCS_DEVICE_NUM_THREADS_IN_WARP;
173176

174177
uct_cuda_ipc_try_copy_aligned<int4>(s1, d1, len, warp_id, num_warps,
175178
lane_id, num_lanes);

0 commit comments

Comments
 (0)