Commit bc7778e
Fix the workspace allocation for the transformer kernel (#1397)
* fix the workspace allocation for the transformer kernel
* change layer-id type & rm one unit test due to OOM
1 parent c6d1418 commit bc7778e

3 files changed: +50, -49 lines


csrc/includes/ds_transformer_cuda.h

Lines changed: 21 additions & 21 deletions
@@ -34,12 +34,12 @@ struct BertGemmAlgos {
 template <typename T>
 class BertTransformerLayer {
 public:
-    BertTransformerLayer(int layer_id,
-                         int batch_size,
-                         int hidden_size,
-                         int num_heads,
-                         int intermediate_size,
-                         int seq_length,
+    BertTransformerLayer(unsigned layer_id,
+                         unsigned batch_size,
+                         unsigned hidden_size,
+                         unsigned num_heads,
+                         unsigned intermediate_size,
+                         unsigned seq_length,
                          float attn_dropout_ratio,
                          float hidden_output_dropout_ratio,
                          float layer_norm_eps,
@@ -52,7 +52,7 @@ class BertTransformerLayer {

     virtual ~BertTransformerLayer();

-    void Forward(int bsz,
+    void Forward(unsigned bsz,
                  const T* input_ptr,
                  const T* input_mask_ptr,
                  const T* attn_qkvw_ptr,
@@ -80,7 +80,7 @@ class BertTransformerLayer {
                  T* gelu_inp_ptr,
                  T* ff2_inp_ptr);

-    void Backward(int bsz,
+    void Backward(unsigned bsz,
                   const T* grad_output_ptr,
                   const T* input_ptr,
                   const T* output_ptr,
@@ -128,13 +128,13 @@ class BertTransformerLayer {
                  T* attn_layer_norm_var,
                  T* attn_layer_norm_mean);

-    inline int GetBatchSize() const { return _batch_size; }
-    inline int GetNumHeads() const { return _heads; }
-    inline int GetSeqLength() const { return _seq_length; }
-    inline int GetIntermediateSize() const { return _intermediate_size; }
+    inline unsigned GetBatchSize() const { return _batch_size; }
+    inline unsigned GetNumHeads() const { return _heads; }
+    inline unsigned GetSeqLength() const { return _seq_length; }
+    inline unsigned GetIntermediateSize() const { return _intermediate_size; }

-    void SetSeqLength(int seq_len);
-    inline int GetHiddenSize() const { return _hidden_size; }
+    void SetSeqLength(unsigned seq_len);
+    inline unsigned GetHiddenSize() const { return _hidden_size; }
     void SetTrainingMode(bool training);
     inline bool IsTrainingMode() const { return _training; }
     inline bool GeluCheckpoint() const { return _gelu_checkpoint; }
@@ -144,13 +144,13 @@ class BertTransformerLayer {
     size_t getWorkspaceSize(int maxBatchSize) const;

     // Params
-    int _layer_id;
-    int _batch_size;
-    int _hidden_size;
-    int _heads;
-    int _size_per_head;
-    int _intermediate_size;
-    int _seq_length;
+    unsigned _layer_id;
+    unsigned _batch_size;
+    unsigned _hidden_size;
+    unsigned _heads;
+    unsigned _size_per_head;
+    unsigned _intermediate_size;
+    unsigned _seq_length;

     bool _pre_or_postLayerNorm;

csrc/transformer/ds_transformer_cuda.cpp

Lines changed: 28 additions & 28 deletions
@@ -19,15 +19,15 @@ const int init_seq_length = 128;
 // C++ interface

 template <typename T>
-int get_workspace_size(int maxBatchSize,
-                       int seq_len,
-                       int hidden_size,
-                       int intermediate_size,
-                       int heads,
-                       bool training,
-                       bool gelu_checkpoint)
+unsigned get_workspace_size(unsigned maxBatchSize,
+                            unsigned seq_len,
+                            unsigned hidden_size,
+                            unsigned intermediate_size,
+                            unsigned heads,
+                            bool training,
+                            bool gelu_checkpoint)
 {
-    int workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
+    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
     if (training) {
         workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
         workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
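The hunk above is the heart of the workspace fix: the element count is now accumulated in an unsigned counter and each product is widened to size_t before multiplying. The following standalone sketch (not DeepSpeed code; the configuration numbers are made up for illustration) shows why a signed 32-bit counter can wrap for large batch x sequence x hidden products while the unsigned counter still holds the value:

// Illustrative sketch of the overflow the commit guards against (hypothetical config).
#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical large configuration; real values come from the layer config.
    unsigned maxBatchSize = 96, seq_len = 2048, hidden_size = 4096;

    // Products are computed in size_t, as in the patched code.
    size_t elems = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);  // 3,221,225,472

    int as_int = static_cast<int>(elems);                 // old counter type: wraps negative
    unsigned as_unsigned = static_cast<unsigned>(elems);  // new counter type: still representable

    std::printf("elements   : %zu\n", elems);
    std::printf("as int     : %d\n", as_int);        // typically prints -1073741824
    std::printf("as unsigned: %u\n", as_unsigned);   // prints 3221225472
    return 0;
}

Keeping the counter unsigned doubles the representable range, and the size_t products inside the parentheses avoid intermediate overflow before the assignment.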
@@ -46,12 +46,12 @@ int get_workspace_size(int maxBatchSize,
 CHECK_CONTIGUOUS(x)

 template <typename T>
-BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
-                                              int batch_size,
-                                              int hidden_size,
-                                              int num_heads,
-                                              int intermediate_size,
-                                              int seq_length,
+BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
+                                              unsigned batch_size,
+                                              unsigned hidden_size,
+                                              unsigned num_heads,
+                                              unsigned intermediate_size,
+                                              unsigned seq_length,
                                               float attn_prob_dropout_ratio,
                                               float hidden_output_dropout_ratio,
                                               float layer_norm_eps,
@@ -144,7 +144,7 @@ void BertTransformerLayer<T>::Initialize()
 }

 template <typename T>
-void BertTransformerLayer<T>::Forward(int bsz,
+void BertTransformerLayer<T>::Forward(unsigned bsz,
                                       const T* input_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
@@ -292,7 +292,7 @@ void BertTransformerLayer<T>::Forward(int bsz,
 }

 template <typename T>
-void BertTransformerLayer<T>::Backward(int bsz,
+void BertTransformerLayer<T>::Backward(unsigned bsz,
                                        const T* grad_output_ptr,
                                        const T* input_ptr,
                                        const T* output_ptr,
@@ -575,7 +575,7 @@ void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_
 }

 template <typename T>
-void BertTransformerLayer<T>::SetSeqLength(int seq_len)
+void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
 {
     _seq_length = seq_len;

@@ -586,11 +586,11 @@ void BertTransformerLayer<T>::SetSeqLength(int seq_len)
 }

 template <typename T>
-int create_transformer_layer(int layer_id,
-                             int batch_size,
-                             int hidden_dim,
-                             int num_heads,
-                             int intermediate_size,
+int create_transformer_layer(unsigned layer_id,
+                             unsigned batch_size,
+                             unsigned hidden_dim,
+                             unsigned num_heads,
+                             unsigned intermediate_size,
                              float attn_dropout_ratio,
                              float hidden_dropout_ratio,
                              float layer_norm_eps,
@@ -633,7 +633,7 @@ int create_transformer_layer(int layer_id,
 }

 template <typename T>
-std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
+std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
@@ -669,7 +669,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     CHECK_INPUT(norm_w);
     CHECK_INPUT(norm_b);

-    int bsz = input.size(0);
+    unsigned bsz = input.size(0);

     const T* input_ptr = (const T*)input.data_ptr();
     const T* input_mask_ptr = (const T*)input_mask.data_ptr();
@@ -704,7 +704,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     std::shared_ptr<BertTransformerLayer<T>> layer =
         std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

-    int seq_len = layer->GetSeqLength();
+    unsigned seq_len = layer->GetSeqLength();
     if (input.size(1) != seq_len) {
         seq_len = input.size(1);
         layer->SetSeqLength(seq_len);
@@ -818,7 +818,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
 }

 template <typename T>
-std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
+std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                    const torch::Tensor& grad_output,
                                                    const torch::Tensor& output,
                                                    const torch::Tensor& inp_norm,
@@ -879,12 +879,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
     CHECK_INPUT(norm_w);
     CHECK_INPUT(norm_b);

-    int bsz = g_output.size(0);
+    unsigned bsz = g_output.size(0);

     std::shared_ptr<BertTransformerLayer<T>> layer =
         std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

-    int seq_len = layer->GetSeqLength();
+    unsigned seq_len = layer->GetSeqLength();
     if (g_output.size(1) != seq_len) {
         seq_len = g_output.size(1);
         layer->SetSeqLength(seq_len);

tests/unit/test_cuda_forward.py

Lines changed: 1 addition & 0 deletions
@@ -199,6 +199,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
 # FP16 test cases can only run on the devices support FP16.
 @pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
                          [
+                             #(8,2048,2048,32,1,True,True),
                              (8,160,128,2,3,True,True),
                              (8,160,128,2,3,False,True),
                              (8,1600,128,2,3,True,True),
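The parametrize entry commented out above is the unit test the commit message removes due to OOM. As a rough back-of-envelope only (not the exact DeepSpeed workspace formula), the attention-score buffers for that configuration already run into gigabytes in FP16:

// Rough sizing of the disabled (batch=8, hidden=2048, seq=2048, heads=32, fp16) case.
// Illustrative estimate: it counts only the batch*heads*seq*seq attention scores.
#include <cstddef>
#include <cstdio>

int main()
{
    size_t batch = 8, heads = 32, seq = 2048;
    size_t score_elems = batch * heads * seq * seq;               // 1,073,741,824 elements
    double gib = score_elems * 2.0 / (1024.0 * 1024.0 * 1024.0);  // 2 bytes per fp16 element

    std::printf("attention scores: %zu elements (~%.1f GiB in fp16)\n", score_elems, gib);
    return 0;
}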
