Commit 7acf790

Use list,tuple,dict for typing
Signed-off-by: cyy <[email protected]>
1 parent 81799d8 commit 7acf790
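
Every hunk in this commit follows the same pattern: the deprecated `typing.Dict`, `typing.List`, and `typing.Tuple` aliases are dropped from the imports, and the annotations are rewritten with the built-in `dict`, `list`, and `tuple` generics (PEP 585), which are subscriptable at runtime on Python 3.9 and later. A minimal sketch of the pattern (the `resize_image` helper below is hypothetical and only mirrors the annotation style; it is not part of the commit):

# Before: from typing import Dict, List, Optional, Tuple, Union
# After: only the names with no built-in equivalent remain imported.
from typing import Optional, Union

import numpy as np


def resize_image(
    image: np.ndarray,
    size: dict[str, int],                                     # was Dict[str, int]
    image_mean: Optional[Union[float, list[float]]] = None,   # was List[float]
) -> tuple[np.ndarray, dict[str, int]]:                       # was Tuple[...], Dict[...]
    """Hypothetical helper mirroring the annotation style used in this commit."""
    target_shape = (size["height"], size["width"])
    # np.resize is only a stand-in; real image processors interpolate instead.
    return np.resize(image, target_shape), size

The runtime behaviour is unchanged; only the annotations differ, which is why the diff is large (1,241 files) but mechanical.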

1,241 files changed: +15,307 additions, -15,353 deletions


examples/modular-transformers/image_processing_new_imgproc_model.py

Lines changed: 8 additions & 8 deletions
@@ -4,7 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_new_imgproc_model.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union

 import numpy as np
 import torch
@@ -74,13 +74,13 @@ class ImgprocModelImageProcessor(BaseImageProcessor):
     def __init__(
         self,
         do_resize: bool = True,
-        size: Optional[Dict[str, int]] = None,
+        size: Optional[dict[str, int]] = None,
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
         rescale_factor: Union[int, float] = 1 / 255,
         do_normalize: bool = True,
-        image_mean: Optional[Union[float, List[float]]] = None,
-        image_std: Optional[Union[float, List[float]]] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
         do_convert_rgb: bool = True,
         **kwargs,
     ) -> None:
@@ -101,7 +101,7 @@ def __init__(
     def resize(
         self,
         image: np.ndarray,
-        size: Dict[str, int],
+        size: dict[str, int],
         resample: PILImageResampling = PILImageResampling.BICUBIC,
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
@@ -151,13 +151,13 @@ def preprocess(
         self,
         images: ImageInput,
         do_resize: Optional[bool] = None,
-        size: Optional[Dict[str, int]] = None,
+        size: Optional[dict[str, int]] = None,
         resample: PILImageResampling = None,
         do_rescale: Optional[bool] = None,
         rescale_factor: Optional[float] = None,
         do_normalize: Optional[bool] = None,
-        image_mean: Optional[Union[float, List[float]]] = None,
-        image_std: Optional[Union[float, List[float]]] = None,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
         return_tensors: Optional[Union[str, TensorType]] = None,
         do_convert_rgb: Optional[bool] = None,
         data_format: ChannelDimension = ChannelDimension.FIRST,

examples/modular-transformers/modeling_add_function.py

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@
 # modular_add_function.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 # Note that zamba does not have the `apply_rotary_pos_emb` function!
-from typing import Optional, Tuple
+from typing import Optional

 import torch
 from torch import nn
@@ -62,5 +62,5 @@ class TestAttention(nn.Module):
     def __init__(self):
         pass

-    def forward(self) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    def forward(self) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         _ = apply_rotary_pos_emb(1, 1, 1, 1)

examples/modular-transformers/modeling_dummy.py

Lines changed: 5 additions & 5 deletions
@@ -4,7 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_dummy.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Union

 import torch
 from torch import nn
@@ -210,12 +210,12 @@ def __init__(self, config: DummyConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

@@ -278,9 +278,9 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

examples/modular-transformers/modeling_dummy_bert.py

Lines changed: 11 additions & 11 deletions
@@ -6,7 +6,7 @@
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 import math
 import os
-from typing import Optional, Tuple, Union
+from typing import Optional, Union

 import torch
 from packaging import version
@@ -136,9 +136,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor]:
+    ) -> tuple[torch.Tensor]:
         mixed_query_layer = self.query(hidden_states)

         # If this is instantiated as a cross-attention module, the keys
@@ -245,9 +245,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor]:
+    ) -> tuple[torch.Tensor]:
         if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
             # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
             logger.warning_once(
@@ -386,9 +386,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor]:
+    ) -> tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,
             attention_mask,
@@ -454,9 +454,9 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor]:
+    ) -> tuple[torch.Tensor]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
         self_attention_outputs = self.attention(
@@ -532,12 +532,12 @@ def forward(
         head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = False,
         output_hidden_states: Optional[bool] = False,
         return_dict: Optional[bool] = True,
-    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
+    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

examples/modular-transformers/modeling_from_uppercase_model.py

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_from_uppercase_model.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Union

 import torch
 from torch import nn
@@ -71,7 +71,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""

         batch_size, seq_length, embed_dim = hidden_states.shape
@@ -153,7 +153,7 @@ def forward(
         attention_mask: torch.Tensor,
         causal_attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> tuple[torch.FloatTensor]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`

examples/modular-transformers/modeling_multimodal1.py

Lines changed: 5 additions & 5 deletions
@@ -4,7 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_multimodal1.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Union

 import torch
 from torch import nn
@@ -210,12 +210,12 @@ def __init__(self, config: Multimodal1TextConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

@@ -278,9 +278,9 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

examples/modular-transformers/modeling_multimodal2.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 # modular_multimodal2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨

-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Union

 import torch
 from torch import nn
@@ -81,7 +81,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""

         batch_size, seq_length, embed_dim = hidden_states.shape
@@ -177,7 +177,7 @@ def forward(
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """Input shape: Batch x Time x Channel"""

         batch_size, seq_length, embed_dim = hidden_states.shape
@@ -244,7 +244,7 @@ def forward(
         attention_mask: torch.Tensor,
         causal_attention_mask: torch.Tensor,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> tuple[torch.FloatTensor]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`

examples/modular-transformers/modeling_my_new_model2.py

Lines changed: 6 additions & 6 deletions
@@ -4,7 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_my_new_model2.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Callable, Optional, Union

 import torch
 from torch import nn
@@ -208,12 +208,12 @@ def __init__(self, config: MyNewModel2Config, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)

@@ -276,9 +276,9 @@ def forward(
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

@@ -469,7 +469,7 @@ def forward(
         input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Union[Cache, list[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,

examples/modular-transformers/modeling_new_task_model.py

Lines changed: 7 additions & 7 deletions
@@ -5,7 +5,7 @@
 # modular_new_task_model.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 from dataclasses import dataclass
-from typing import ClassVar, List, Optional, Tuple, Union
+from typing import ClassVar, Optional, Union

 import torch
 from torch import nn
@@ -88,9 +88,9 @@ class NewTaskModelCausalLMOutputWithPast(ModelOutput):

     loss: Optional[torch.FloatTensor] = None
     logits: Optional[torch.FloatTensor] = None
-    past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    attentions: Optional[tuple[torch.FloatTensor]] = None
     image_hidden_states: Optional[torch.FloatTensor] = None


@@ -249,7 +249,7 @@ def forward(
         pixel_values: torch.FloatTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
+        past_key_values: Optional[Union[list[torch.FloatTensor], Cache]] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
@@ -259,7 +259,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Union[Tuple, NewTaskModelModelOutputWithPast]:
+    ) -> Union[tuple, NewTaskModelModelOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -442,7 +442,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         num_logits_to_keep: int = 0,
-    ) -> Union[Tuple, NewTaskModelCausalLMOutputWithPast]:
+    ) -> Union[tuple, NewTaskModelCausalLMOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
