Fix the workspace allocation for the transformer kernel (#1397)
* fix the workspace allocation for the transformer kernel

* change layer-id type & rm one unit test due to OOM
RezaYazdaniAminabadi authored Oct 6, 2021
1 parent c6d1418 commit bc7778e
Showing 3 changed files with 50 additions and 49 deletions.
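The visible changes below are type changes: get_workspace_size and its element-count accumulator, the constructor and Forward/Backward entry points, and the layer/shape members all move from int to unsigned, while the size products stay widened through size_t(maxBatchSize). A minimal sketch of why signed 32-bit arithmetic is the hazard here (this is not the DeepSpeed source; the dimensions are hypothetical, chosen only so the product crosses 2^31):

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical shapes, picked only so batch * seq * hidden reaches 2^31.
    int bsz = 64, seq = 4096, hidden = 8192;

    // All-int product: 64 * 4096 * 8192 == 2^31 overflows signed 32-bit
    // arithmetic (undefined behaviour), so a workspace request built from it
    // would be garbage.
    // int bad = bsz * seq * hidden;

    // Widening the first operand, as get_workspace_size does with
    // size_t(maxBatchSize) * seq_len * hidden_size, keeps the whole chain
    // in 64-bit arithmetic.
    size_t elements = static_cast<size_t>(bsz) * seq * hidden;

    // A count just above INT_MAX still fits in unsigned (the new accumulator
    // type) but not in int (the old one).
    std::printf("%zu elements, as unsigned: %u\n",
                elements, static_cast<unsigned>(elements));
    return 0;
}

Compiled and run, this prints 2147483648 for both formats, whereas the commented-out all-int product would have wrapped.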
42 changes: 21 additions & 21 deletions csrc/includes/ds_transformer_cuda.h
@@ -34,12 +34,12 @@ struct BertGemmAlgos {
template <typename T>
class BertTransformerLayer {
public:
- BertTransformerLayer(int layer_id,
- int batch_size,
- int hidden_size,
- int num_heads,
- int intermediate_size,
- int seq_length,
+ BertTransformerLayer(unsigned layer_id,
+ unsigned batch_size,
+ unsigned hidden_size,
+ unsigned num_heads,
+ unsigned intermediate_size,
+ unsigned seq_length,
float attn_dropout_ratio,
float hidden_output_dropout_ratio,
float layer_norm_eps,
@@ -52,7 +52,7 @@ class BertTransformerLayer {

virtual ~BertTransformerLayer();

- void Forward(int bsz,
+ void Forward(unsigned bsz,
const T* input_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
@@ -80,7 +80,7 @@ class BertTransformerLayer {
T* gelu_inp_ptr,
T* ff2_inp_ptr);

- void Backward(int bsz,
+ void Backward(unsigned bsz,
const T* grad_output_ptr,
const T* input_ptr,
const T* output_ptr,
@@ -128,13 +128,13 @@ class BertTransformerLayer {
T* attn_layer_norm_var,
T* attn_layer_norm_mean);

- inline int GetBatchSize() const { return _batch_size; }
- inline int GetNumHeads() const { return _heads; }
- inline int GetSeqLength() const { return _seq_length; }
- inline int GetIntermediateSize() const { return _intermediate_size; }
+ inline unsigned GetBatchSize() const { return _batch_size; }
+ inline unsigned GetNumHeads() const { return _heads; }
+ inline unsigned GetSeqLength() const { return _seq_length; }
+ inline unsigned GetIntermediateSize() const { return _intermediate_size; }

- void SetSeqLength(int seq_len);
- inline int GetHiddenSize() const { return _hidden_size; }
+ void SetSeqLength(unsigned seq_len);
+ inline unsigned GetHiddenSize() const { return _hidden_size; }
void SetTrainingMode(bool training);
inline bool IsTrainingMode() const { return _training; }
inline bool GeluCheckpoint() const { return _gelu_checkpoint; }
@@ -144,13 +144,13 @@ class BertTransformerLayer {
size_t getWorkspaceSize(int maxBatchSize) const;

// Params
- int _layer_id;
- int _batch_size;
- int _hidden_size;
- int _heads;
- int _size_per_head;
- int _intermediate_size;
- int _seq_length;
+ unsigned _layer_id;
+ unsigned _batch_size;
+ unsigned _hidden_size;
+ unsigned _heads;
+ unsigned _size_per_head;
+ unsigned _intermediate_size;
+ unsigned _seq_length;

bool _pre_or_postLayerNorm;

56 changes: 28 additions & 28 deletions csrc/transformer/ds_transformer_cuda.cpp
@@ -19,15 +19,15 @@ const int init_seq_length = 128;
// C++ interface

template <typename T>
- int get_workspace_size(int maxBatchSize,
- int seq_len,
- int hidden_size,
- int intermediate_size,
- int heads,
- bool training,
- bool gelu_checkpoint)
+ unsigned get_workspace_size(unsigned maxBatchSize,
+ unsigned seq_len,
+ unsigned hidden_size,
+ unsigned intermediate_size,
+ unsigned heads,
+ bool training,
+ bool gelu_checkpoint)
{
- int workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
+ unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
if (training) {
workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
@@ -46,12 +46,12 @@ int get_workspace_size(int maxBatchSize,
CHECK_CONTIGUOUS(x)

template <typename T>
- BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
- int batch_size,
- int hidden_size,
- int num_heads,
- int intermediate_size,
- int seq_length,
+ BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
+ unsigned batch_size,
+ unsigned hidden_size,
+ unsigned num_heads,
+ unsigned intermediate_size,
+ unsigned seq_length,
float attn_prob_dropout_ratio,
float hidden_output_dropout_ratio,
float layer_norm_eps,
@@ -144,7 +144,7 @@ void BertTransformerLayer<T>::Initialize()
}

template <typename T>
- void BertTransformerLayer<T>::Forward(int bsz,
+ void BertTransformerLayer<T>::Forward(unsigned bsz,
const T* input_ptr,
const T* input_mask_ptr,
const T* attn_qkvw_ptr,
@@ -292,7 +292,7 @@ void BertTransformerLayer<T>::Forward(int bsz,
}

template <typename T>
- void BertTransformerLayer<T>::Backward(int bsz,
+ void BertTransformerLayer<T>::Backward(unsigned bsz,
const T* grad_output_ptr,
const T* input_ptr,
const T* output_ptr,
@@ -575,7 +575,7 @@ void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_
}

template <typename T>
- void BertTransformerLayer<T>::SetSeqLength(int seq_len)
+ void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
{
_seq_length = seq_len;

@@ -586,11 +586,11 @@ void BertTransformerLayer<T>::SetSeqLength(int seq_len)
}

template <typename T>
- int create_transformer_layer(int layer_id,
- int batch_size,
- int hidden_dim,
- int num_heads,
- int intermediate_size,
+ int create_transformer_layer(unsigned layer_id,
+ unsigned batch_size,
+ unsigned hidden_dim,
+ unsigned num_heads,
+ unsigned intermediate_size,
float attn_dropout_ratio,
float hidden_dropout_ratio,
float layer_norm_eps,
@@ -633,7 +633,7 @@ int create_transformer_layer(int layer_id,
}

template <typename T>
- std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
+ std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
const torch::Tensor& input,
const torch::Tensor& input_mask,
const torch::Tensor& attn_qkvw,
@@ -669,7 +669,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);

- int bsz = input.size(0);
+ unsigned bsz = input.size(0);

const T* input_ptr = (const T*)input.data_ptr();
const T* input_mask_ptr = (const T*)input_mask.data_ptr();
@@ -704,7 +704,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

- int seq_len = layer->GetSeqLength();
+ unsigned seq_len = layer->GetSeqLength();
if (input.size(1) != seq_len) {
seq_len = input.size(1);
layer->SetSeqLength(seq_len);
@@ -818,7 +818,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
}

template <typename T>
- std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
+ std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
const torch::Tensor& grad_output,
const torch::Tensor& output,
const torch::Tensor& inp_norm,
@@ -879,12 +879,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
CHECK_INPUT(norm_w);
CHECK_INPUT(norm_b);

- int bsz = g_output.size(0);
+ unsigned bsz = g_output.size(0);

std::shared_ptr<BertTransformerLayer<T>> layer =
std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);

- int seq_len = layer->GetSeqLength();
+ unsigned seq_len = layer->GetSeqLength();
if (g_output.size(1) != seq_len) {
seq_len = g_output.size(1);
layer->SetSeqLength(seq_len);
1 change: 1 addition & 0 deletions tests/unit/test_cuda_forward.py
@@ -199,6 +199,7 @@ def run_forward(ds_config, seq_len, atol=1e-2, verbose=False, test_bsz=None):
# FP16 test cases can only run on the devices support FP16.
@pytest.mark.parametrize('batch_size, hidden_size, seq_len, heads, num_layers, is_preln, use_fp16',
[
+ #(8,2048,2048,32,1,True,True),
(8,160,128,2,3,True,True),
(8,160,128,2,3,False,True),
(8,1600,128,2,3,True,True),
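The commented-out case above is the test dropped because of OOM. A rough estimate, independent of the exact terms in get_workspace_size (which are partly elided in this diff), shows why that shape is far heavier than the remaining cases:

#include <cstddef>
#include <cstdio>

int main()
{
    // Shape of the disabled case: batch 8, hidden 2048, seq_len 2048,
    // 32 heads, fp16 (2 bytes per element).
    const size_t bsz = 8, seq = 2048, heads = 32, bytes_per_elem = 2;

    // The attention-score tensor alone is batch * heads * seq * seq elements.
    size_t score_elems = bsz * heads * seq * seq;   // 1,073,741,824
    std::printf("attention scores: %zu elements, ~%.1f GiB\n",
                score_elems,
                double(score_elems * bytes_per_elem) / double(1ULL << 30));

    // The remaining cases keep seq_len at 128, which shrinks this term
    // by a factor of (2048 / 128)^2 = 256.
    return 0;
}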
