@@ -19,15 +19,15 @@ const int init_seq_length = 128;
 // C++ interface
 
 template <typename T>
-int get_workspace_size(int maxBatchSize,
-                       int seq_len,
-                       int hidden_size,
-                       int intermediate_size,
-                       int heads,
-                       bool training,
-                       bool gelu_checkpoint)
+unsigned get_workspace_size(unsigned maxBatchSize,
+                            unsigned seq_len,
+                            unsigned hidden_size,
+                            unsigned intermediate_size,
+                            unsigned heads,
+                            bool training,
+                            bool gelu_checkpoint)
 {
-    int workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
+    unsigned workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
     if (training) {
         workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size);
         workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size),
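The sizing above counts workspace in elements: a fixed multiple of maxBatchSize * seq_len * hidden_size, with the products carried in size_t so they cannot overflow a 32-bit type mid-expression. A minimal standalone sketch of the non-training branch, with hypothetical dimensions:

#include <cstddef>
#include <cstdio>

// Sketch of the non-training branch of get_workspace_size above.
// The dimensions are hypothetical; the point is that the multiplication
// happens in size_t before the result is narrowed to the return type.
int main()
{
    unsigned maxBatchSize = 64, seq_len = 512, hidden_size = 1024;
    size_t workSpacesize = 4 * (size_t(maxBatchSize) * seq_len * hidden_size);
    std::printf("workspace elements: %zu\n", workSpacesize);  // 134217728
    return 0;
}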
@@ -46,12 +46,12 @@ int get_workspace_size(int maxBatchSize,
     CHECK_CONTIGUOUS(x)
 
 template <typename T>
-BertTransformerLayer<T>::BertTransformerLayer(int layer_id,
-                                              int batch_size,
-                                              int hidden_size,
-                                              int num_heads,
-                                              int intermediate_size,
-                                              int seq_length,
+BertTransformerLayer<T>::BertTransformerLayer(unsigned layer_id,
+                                              unsigned batch_size,
+                                              unsigned hidden_size,
+                                              unsigned num_heads,
+                                              unsigned intermediate_size,
+                                              unsigned seq_length,
                                               float attn_prob_dropout_ratio,
                                               float hidden_output_dropout_ratio,
                                               float layer_norm_eps,
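The bare CHECK_CONTIGUOUS(x) context line at the top of this hunk is presumably the tail of the usual PyTorch C++-extension guard macros defined earlier in the file; a typical definition (assumed here, not taken from this diff) looks like:

#include <torch/extension.h>

// Assumed definitions of the guard macros used throughout this file;
// this is the standard PyTorch extension idiom, not necessarily the
// exact wording in this source.
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)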
@@ -144,7 +144,7 @@ void BertTransformerLayer<T>::Initialize()
 }
 
 template <typename T>
-void BertTransformerLayer<T>::Forward(int bsz,
+void BertTransformerLayer<T>::Forward(unsigned bsz,
                                       const T* input_ptr,
                                       const T* input_mask_ptr,
                                       const T* attn_qkvw_ptr,
@@ -292,7 +292,7 @@ void BertTransformerLayer<T>::Forward(int bsz,
 }
 
 template <typename T>
-void BertTransformerLayer<T>::Backward(int bsz,
+void BertTransformerLayer<T>::Backward(unsigned bsz,
                                        const T* grad_output_ptr,
                                        const T* input_ptr,
                                        const T* output_ptr,
@@ -575,7 +575,7 @@ void BertTransformerLayer<T>::SetIntermediateBuffers(uint8_t* attn_prob_dropout_
 }
 
 template <typename T>
-void BertTransformerLayer<T>::SetSeqLength(int seq_len)
+void BertTransformerLayer<T>::SetSeqLength(unsigned seq_len)
 {
     _seq_length = seq_len;
 
@@ -586,11 +586,11 @@ void BertTransformerLayer<T>::SetSeqLength(int seq_len)
 }
 
 template <typename T>
-int create_transformer_layer(int layer_id,
-                             int batch_size,
-                             int hidden_dim,
-                             int num_heads,
-                             int intermediate_size,
+int create_transformer_layer(unsigned layer_id,
+                             unsigned batch_size,
+                             unsigned hidden_dim,
+                             unsigned num_heads,
+                             unsigned intermediate_size,
                              float attn_dropout_ratio,
                              float hidden_dropout_ratio,
                              float layer_norm_eps,
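The forward and backward entry points in the hunks below fetch a layer by layer_id from a process-global registry and static_pointer_cast it back to its concrete type. That cast implies the registry stores type-erased shared pointers, along these (assumed) lines:

#include <memory>
#include <unordered_map>

// Assumed shape of the registry that create_transformer_layer populates
// and ds_transformer_forward/backward read. shared_ptr<void> type-erases
// the BertTransformerLayer<T> instantiations (e.g. float and __half)
// behind one container; std::static_pointer_cast recovers the real type.
static std::unordered_map<int, std::shared_ptr<void>> s_transformer_layers;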
@@ -633,7 +633,7 @@ int create_transformer_layer(int layer_id,
 }
 
 template <typename T>
-std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
+std::vector<torch::Tensor> ds_transformer_forward(unsigned layer_id,
                                                   const torch::Tensor& input,
                                                   const torch::Tensor& input_mask,
                                                   const torch::Tensor& attn_qkvw,
@@ -669,7 +669,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     CHECK_INPUT(norm_w);
     CHECK_INPUT(norm_b);
 
-    int bsz = input.size(0);
+    unsigned bsz = input.size(0);
 
     const T* input_ptr = (const T*)input.data_ptr();
     const T* input_mask_ptr = (const T*)input_mask.data_ptr();
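One subtlety of the int-to-unsigned switch: at::Tensor::size() returns int64_t, so the bsz assignment above narrows, and the later size comparisons mix signedness. A small sketch of the resulting semantics on a typical 64-bit platform:

#include <cstdint>
#include <iostream>

int main()
{
    int64_t dim = 512;   // what input.size(0) or input.size(1) returns
    unsigned bsz = dim;  // narrowing conversion; fine while dim fits in 32 bits
    // In a mixed comparison the unsigned operand converts to int64_t
    // (int64_t can represent every unsigned int value), so checks like
    // `input.size(1) != seq_len` behave as expected for realistic sizes.
    std::cout << (dim != bsz ? "resize" : "keep") << "\n";  // prints "keep"
    return 0;
}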
@@ -704,7 +704,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
     std::shared_ptr<BertTransformerLayer<T>> layer =
         std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
 
-    int seq_len = layer->GetSeqLength();
+    unsigned seq_len = layer->GetSeqLength();
     if (input.size(1) != seq_len) {
         seq_len = input.size(1);
         layer->SetSeqLength(seq_len);
@@ -818,7 +818,7 @@ std::vector<torch::Tensor> ds_transformer_forward(int layer_id,
 }
 
 template <typename T>
-std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
+std::vector<torch::Tensor> ds_transformer_backward(unsigned layer_id,
                                                    const torch::Tensor& grad_output,
                                                    const torch::Tensor& output,
                                                    const torch::Tensor& inp_norm,
@@ -879,12 +879,12 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
     CHECK_INPUT(norm_w);
     CHECK_INPUT(norm_b);
 
-    int bsz = g_output.size(0);
+    unsigned bsz = g_output.size(0);
 
     std::shared_ptr<BertTransformerLayer<T>> layer =
         std::static_pointer_cast<BertTransformerLayer<T>>(s_transformer_layers[layer_id]);
 
-    int seq_len = layer->GetSeqLength();
+    unsigned seq_len = layer->GetSeqLength();
     if (g_output.size(1) != seq_len) {
         seq_len = g_output.size(1);
         layer->SetSeqLength(seq_len);