from ...activations import ACT2FN
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
+ from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
- from ...utils import (
-     add_code_sample_docstrings,
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     get_torch_version,
-     logging,
- )
+ from ...utils import auto_docstring, get_torch_version, logging
from .configuration_dummy_bert import DummyBertConfig


logger = logging.get_logger(__name__)

- _CHECKPOINT_FOR_DOC = "google-dummy_bert/dummy_bert-base-uncased"
- _CONFIG_FOR_DOC = "DummyBertConfig"
-


class DummyBertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""
@@ -432,7 +424,7 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


- class DummyBertLayer(nn.Module):
+ class DummyBertLayer(GradientCheckpointingLayer):
    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
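For context on the change above: `GradientCheckpointingLayer` moves the checkpointing decision out of the encoder loop and into the layer's own `__call__`. The following is a minimal sketch of that idea, assuming only the behaviour implied by this diff; it is not the actual implementation in `...modeling_layers`, and the `Sketch` class name is made up:

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class GradientCheckpointingLayerSketch(nn.Module):
    """Illustrative stand-in: when checkpointing is enabled and the module is in
    training mode, reroute the call through torch.utils.checkpoint so the layer's
    activations are recomputed during backward instead of being stored."""

    gradient_checkpointing = False  # toggled by the surrounding model

    def __call__(self, *args, **kwargs):
        if self.gradient_checkpointing and self.training:
            return checkpoint(super().__call__, *args, use_reentrant=False, **kwargs)
        return super().__call__(*args, **kwargs)
```

Because each layer now makes this decision itself, callers such as the encoder no longer need an explicit `if self.gradient_checkpointing and self.training:` branch, which is exactly what the next hunk removes.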
@@ -557,27 +549,15 @@ def forward(
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

-             if self.gradient_checkpointing and self.training:
-                 layer_outputs = self._gradient_checkpointing_func(
-                     layer_module.__call__,
-                     hidden_states,
-                     attention_mask,
-                     layer_head_mask,
-                     encoder_hidden_states,
-                     encoder_attention_mask,
-                     past_key_value,
-                     output_attentions,
-                 )
-             else:
-                 layer_outputs = layer_module(
-                     hidden_states,
-                     attention_mask,
-                     layer_head_mask,
-                     encoder_hidden_states,
-                     encoder_attention_mask,
-                     past_key_value,
-                     output_attentions,
-                 )
+             layer_outputs = layer_module(
+                 hidden_states,
+                 attention_mask,
+                 layer_head_mask,
+                 encoder_hidden_states,  # as a positional argument for gradient checkpointing
+                 encoder_attention_mask=encoder_attention_mask,
+                 past_key_value=past_key_value,
+                 output_attentions=output_attentions,
+             )

            hidden_states = layer_outputs[0]
            if use_cache:
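The `# as a positional argument for gradient checkpointing` comment above is the stated reason for the reordered call. As background, the reentrant form of `torch.utils.checkpoint` does not accept keyword arguments, so tensors that must carry gradients through a checkpointed call, such as `encoder_hidden_states`, are safest passed positionally. A small, self-contained illustration of that PyTorch behaviour; the function and variable names below are made up and unrelated to this model:

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(x, context):
    # Toy stand-in for a layer that mixes its input with encoder states.
    return x * 2 + context.sum()


x = torch.randn(2, 4, requires_grad=True)
context = torch.randn(2, 4, requires_grad=True)

# Reentrant checkpointing takes the wrapped function's inputs positionally.
out = checkpoint(block, x, context, use_reentrant=True)
out.sum().backward()
print(x.grad is not None, context.grad is not None)  # True True
```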
@@ -739,12 +719,8 @@ def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path):
    return model


+ @auto_docstring
class DummyBertPreTrainedModel(PreTrainedModel):
-     """
-     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
-     models.
-     """
-
    config_class = DummyBertConfig
    load_tf_weights = load_tf_weights_in_dummy_bert
    base_model_prefix = "dummy_bert"
@@ -770,79 +746,8 @@ def _init_weights(self, module):
            module.bias.data.zero_()


- DUMMY_BERT_START_DOCSTRING = r"""
-
-     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-     etc.)
-
-     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-     and behavior.
-
-     Parameters:
-         config ([`DummyBertConfig`]): Model configuration class with all the parameters of the model.
-             Initializing with a config file does not load the weights associated with the model, only the
-             configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
- """
-
- DUMMY_BERT_INPUTS_DOCSTRING = r"""
-     Args:
-         input_ids (`torch.LongTensor` of shape `({0})`):
-             Indices of input sequence tokens in the vocabulary.
-
-             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
-             [`PreTrainedTokenizer.__call__`] for details.
-
-             [What are input IDs?](../glossary#input-ids)
-         attention_mask (`torch.FloatTensor` of shape `({0})` or `(batch_size, sequence_length, target_length)`, *optional*):
-             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
-             - 1 for tokens that are **not masked**,
-             - 0 for tokens that are **masked**.
-
-             [What are attention masks?](../glossary#attention-mask)
-         token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-             Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
-             1]`:
-
-             - 0 corresponds to a *sentence A* token,
-             - 1 corresponds to a *sentence B* token.
-
-             [What are token type IDs?](../glossary#token-type-ids)
-         position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
-             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
-             config.max_position_embeddings - 1]`.
-
-             [What are position IDs?](../glossary#position-ids)
-         head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
-             Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
-             - 1 indicates the head is **not masked**,
-             - 0 indicates the head is **masked**.
-
-         inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
-             Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
-             is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
-             model's internal embedding lookup matrix.
-         output_attentions (`bool`, *optional*):
-             Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
-             tensors for more detail.
-         output_hidden_states (`bool`, *optional*):
-             Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
-             more detail.
-         return_dict (`bool`, *optional*):
-             Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
- """
-
-
- @add_start_docstrings(
-     "The bare DummyBert Model transformer outputting raw hidden-states without any specific head on top.",
-     DUMMY_BERT_START_DOCSTRING,
- )
- class DummyBertModel(DummyBertPreTrainedModel):
-     """
-
+ @auto_docstring(
+     custom_intro="""
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
@@ -852,10 +757,15 @@ class DummyBertModel(DummyBertPreTrainedModel):
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
-
+ )
+ class DummyBertModel(DummyBertPreTrainedModel):
    _no_split_modules = ["DummyBertEmbeddings", "DummyBertLayer"]

    def __init__(self, config, add_pooling_layer=True):
+         r"""
+         add_pooling_layer (bool, *optional*, defaults to `True`):
+             Whether to add a pooling layer
+         """
        super().__init__(config)
        self.config = config
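For reference, a hedged usage sketch of the decoder configuration described in the `custom_intro` above: `is_decoder=True` enables the decoder behaviour and `add_cross_attention=True` adds the cross-attention layers that consume `encoder_hidden_states`. The import path is hypothetical (this is a dummy test model), and the shapes are illustrative only:

```python
import torch

# Hypothetical import path for the dummy model in this diff.
from transformers.models.dummy_bert import DummyBertConfig, DummyBertModel

config = DummyBertConfig(is_decoder=True, add_cross_attention=True)
model = DummyBertModel(config)

input_ids = torch.randint(0, config.vocab_size, (1, 8))
encoder_hidden_states = torch.randn(1, 16, config.hidden_size)  # output of some encoder

outputs = model(input_ids=input_ids, encoder_hidden_states=encoder_hidden_states)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 8, hidden_size])
```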
@@ -884,12 +794,7 @@ class PreTrainedModel
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

-     @add_start_docstrings_to_model_forward(DUMMY_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-     @add_code_sample_docstrings(
-         checkpoint=_CHECKPOINT_FOR_DOC,
-         output_type=BaseModelOutputWithPoolingAndCrossAttentions,
-         config_class=_CONFIG_FOR_DOC,
-     )
+     @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
@@ -906,26 +811,6 @@ def forward(
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
-         r"""
-         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-             Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
-             the model is configured as a decoder.
-         encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, target_length)`, *optional*):
-             Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
-             the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
-
-             - 1 for tokens that are **not masked**,
-             - 0 for tokens that are **masked**.
-         past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
-             Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
-
-             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
-             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
-             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
-         use_cache (`bool`, *optional*):
-             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
-             `past_key_values`).
-         """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states