
Commit d058f81

Post-PR fixes! (#38868)
* Post-PR fixes!
* make fix-copies
1 parent 508a704 commit d058f81


4 files changed: +40, -41 lines

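All four diffs apply the same mechanical change: the deprecated `typing.List`, `typing.Dict`, and `typing.Tuple` aliases are replaced by the built-in generics `list`, `dict`, and `tuple` (PEP 585, available since Python 3.9), and the now-unused names are dropped from the `typing` imports. A minimal before/after sketch of the pattern, using a hypothetical function rather than a line from the diff:

# Before: container types had to be imported from typing
# from typing import Dict, List, Tuple
# def pair_sizes(sizes: List[Tuple[int, int]]) -> Dict[str, int]: ...

# After (Python >= 3.9): built-in generics, fewer imports
from typing import Optional

def pair_sizes(sizes: list[tuple[int, int]], threshold: Optional[float] = None) -> dict[str, int]:
    return {"num_pairs": len(sizes)}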

src/transformers/models/lightglue/convert_lightglue_to_hf.py

Lines changed: 1 addition & 2 deletions
@@ -15,7 +15,6 @@
 import gc
 import os
 import re
-from typing import List
 
 import torch
 from datasets import load_dataset
@@ -90,7 +89,7 @@ def verify_model_outputs(model, device):
 }
 
 
-def convert_old_keys_to_new_keys(state_dict_keys: List[str]):
+def convert_old_keys_to_new_keys(state_dict_keys: list[str]):
     """
     This function should be applied only once, on the concatenated keys to efficiently rename using
     the key mappings.
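For context, the docstring above describes the renaming strategy used in this conversion script: instead of looping over keys one at a time, all state-dict keys are joined into a single string and every rename pattern is applied once over that string. A rough sketch of that idea, with a made-up ORIGINAL_TO_CONVERTED_KEY_MAPPING (the real mapping and helper body are not part of this diff):

import re

# Hypothetical regex -> replacement mapping, for illustration only
ORIGINAL_TO_CONVERTED_KEY_MAPPING = {
    r"self_attn": "attention",
    r"\bmlp\b": "feed_forward",
}

def convert_old_keys_to_new_keys(state_dict_keys: list[str]) -> dict[str, str]:
    # Apply each pattern once to the newline-joined keys, then map old -> new
    old_text = "\n".join(state_dict_keys)
    new_text = old_text
    for pattern, replacement in ORIGINAL_TO_CONVERTED_KEY_MAPPING.items():
        new_text = re.sub(pattern, replacement, new_text)
    return dict(zip(old_text.split("\n"), new_text.split("\n")))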

src/transformers/models/lightglue/image_processing_lightglue.py

Lines changed: 13 additions & 13 deletions
@@ -17,7 +17,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -139,7 +139,7 @@ class LightGlueImageProcessor(BaseImageProcessor):
         do_resize (`bool`, *optional*, defaults to `True`):
             Controls whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden
             by `do_resize` in the `preprocess` method.
-        size (`Dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`):
+        size (`dict[str, int]` *optional*, defaults to `{"height": 480, "width": 640}`):
             Resolution of the output image after `resize` is applied. Only has an effect if `do_resize` is set to
             `True`. Can be overridden by `size` in the `preprocess` method.
         resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
@@ -159,7 +159,7 @@ class LightGlueImageProcessor(BaseImageProcessor):
     def __init__(
         self,
         do_resize: bool = True,
-        size: Optional[Dict[str, int]] = None,
+        size: Optional[dict[str, int]] = None,
         resample: PILImageResampling = PILImageResampling.BILINEAR,
         do_rescale: bool = True,
         rescale_factor: float = 1 / 255,
@@ -180,7 +180,7 @@ def __init__(
     def resize(
         self,
         image: np.ndarray,
-        size: Dict[str, int],
+        size: dict[str, int],
         data_format: Optional[Union[str, ChannelDimension]] = None,
         input_data_format: Optional[Union[str, ChannelDimension]] = None,
         **kwargs,
@@ -191,7 +191,7 @@ def resize(
         Args:
             image (`np.ndarray`):
                 Image to resize.
-            size (`Dict[str, int]`):
+            size (`dict[str, int]`):
                 Dictionary of the form `{"height": int, "width": int}`, specifying the size of the output image.
             data_format (`ChannelDimension` or `str`, *optional*):
                 The channel dimension format of the output image. If not provided, it will be inferred from the input
@@ -220,7 +220,7 @@ def preprocess(
         self,
         images,
         do_resize: Optional[bool] = None,
-        size: Optional[Dict[str, int]] = None,
+        size: Optional[dict[str, int]] = None,
         resample: PILImageResampling = None,
         do_rescale: Optional[bool] = None,
         rescale_factor: Optional[float] = None,
@@ -240,7 +240,7 @@ def preprocess(
                 `do_rescale=False`.
             do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                 Whether to resize the image.
-            size (`Dict[str, int]`, *optional*, defaults to `self.size`):
+            size (`dict[str, int]`, *optional*, defaults to `self.size`):
                 Size of the output image after `resize` has been applied. If `size["shortest_edge"]` >= 384, the image
                 is resized to `(size["shortest_edge"], size["shortest_edge"])`. Otherwise, the smaller edge of the
                 image will be matched to `int(size["shortest_edge"]/ crop_pct)`, after which the image is cropped to
@@ -337,31 +337,31 @@ def preprocess(
     def post_process_keypoint_matching(
         self,
         outputs: LightGlueKeypointMatchingOutput,
-        target_sizes: Union[TensorType, List[Tuple]],
+        target_sizes: Union[TensorType, list[tuple]],
         threshold: float = 0.0,
-    ) -> List[Dict[str, torch.Tensor]]:
+    ) -> list[dict[str, torch.Tensor]]:
         """
         Converts the raw output of [`KeypointMatchingOutput`] into lists of keypoints, scores and descriptors
         with coordinates absolute to the original image sizes.
         Args:
             outputs ([`KeypointMatchingOutput`]):
                 Raw outputs of the model.
-            target_sizes (`torch.Tensor` or `List[Tuple[Tuple[int, int]]]`, *optional*):
-                Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`Tuple[int, int]`) containing the
+            target_sizes (`torch.Tensor` or `list[tuple[tuple[int, int]]]`, *optional*):
+                Tensor of shape `(batch_size, 2, 2)` or list of tuples of tuples (`tuple[int, int]`) containing the
                 target size `(height, width)` of each image in the batch. This must be the original image size (before
                 any processing).
             threshold (`float`, *optional*, defaults to 0.0):
                 Threshold to filter out the matches with low scores.
         Returns:
-            `List[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
+            `list[Dict]`: A list of dictionaries, each dictionary containing the keypoints in the first and second image
             of the pair, the matching scores and the matching indices.
         """
         if outputs.mask.shape[0] != len(target_sizes):
            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the mask")
         if not all(len(target_size) == 2 for target_size in target_sizes):
            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
-        if isinstance(target_sizes, List):
+        if isinstance(target_sizes, list):
             image_pair_sizes = torch.tensor(target_sizes, device=outputs.mask.device)
         else:
             if target_sizes.shape[1] != 2 or target_sizes.shape[2] != 2:
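The updated signature and the checks above spell out the contract of `post_process_keypoint_matching`: `target_sizes` is either a tensor of shape `(batch_size, 2, 2)` or a list with one `((height, width), (height, width))` entry per image pair, giving the original size of both images. A minimal usage sketch, assuming `processor` and `outputs` were obtained elsewhere (their construction is not shown in this diff):

# `processor` is a LightGlueImageProcessor and `outputs` a LightGlueKeypointMatchingOutput
# for a single image pair whose originals were both 480x640 (assumed values).
target_sizes = [((480, 640), (480, 640))]
matches = processor.post_process_keypoint_matching(outputs, target_sizes=target_sizes, threshold=0.2)
for pair in matches:
    # one dict per pair holding keypoints of both images plus matching scores and indices
    print({name: tensor.shape for name, tensor in pair.items()})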

src/transformers/models/lightglue/modeling_lightglue.py

Lines changed: 12 additions & 12 deletions
@@ -18,7 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Callable, Optional, Tuple, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import torch
@@ -74,8 +74,8 @@ class LightGlueKeypointMatchingOutput(ModelOutput):
     keypoints: Optional[torch.FloatTensor] = None
     prune: Optional[torch.IntTensor] = None
     mask: Optional[torch.FloatTensor] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    attentions: Optional[tuple[torch.FloatTensor]] = None
 
 
 class LightGluePositionalEncoder(nn.Module):
@@ -85,7 +85,7 @@ def __init__(self, config: LightGlueConfig):
 
     def forward(
         self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
-    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
         projected_keypoints = self.projector(keypoints)
         embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
         cosines = torch.cos(embeddings)
@@ -200,12 +200,12 @@ def __init__(self, config: LightGlueConfig, layer_idx: int):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -274,7 +274,7 @@ def forward(
         attention_mask: torch.Tensor,
         output_hidden_states: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
 
@@ -435,7 +435,7 @@ def _init_weights(self, module: nn.Module) -> None:
             module.weight.data.fill_(1.0)
 
 
-def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]:
+def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
     """obtain matches from a score matrix [Bx M+1 x N+1]"""
     batch_size, _, _ = scores.shape
     # For each keypoint, get the best match
@@ -548,7 +548,7 @@ def _get_confidence_threshold(self, layer_index: int) -> float:
 
     def _keypoint_processing(
         self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         descriptors = descriptors.detach().contiguous()
         projected_descriptors = self.input_projection(descriptors)
         keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
@@ -659,7 +659,7 @@ def _do_final_keypoint_pruning(
         matches: torch.Tensor,
         matching_scores: torch.Tensor,
         num_keypoints: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
         # have tensors from
         batch_size, _ = indices.shape
@@ -699,7 +699,7 @@ def _match_image_pair(
         mask: torch.Tensor = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple, Tuple]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
 
@@ -875,7 +875,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-    ) -> Union[Tuple, LightGlueKeypointMatchingOutput]:
+    ) -> Union[tuple, LightGlueKeypointMatchingOutput]:
         loss = None
         if labels is not None:
             raise ValueError("LightGlue is not trainable, no labels should be provided.")

src/transformers/models/lightglue/modular_lightglue.py

Lines changed: 14 additions & 14 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Optional, Union
 
 import numpy as np
 import torch
@@ -196,17 +196,17 @@ class LightGlueKeypointMatchingOutput(ModelOutput):
     keypoints: Optional[torch.FloatTensor] = None
     prune: Optional[torch.IntTensor] = None
     mask: Optional[torch.FloatTensor] = None
-    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    hidden_states: Optional[tuple[torch.FloatTensor]] = None
+    attentions: Optional[tuple[torch.FloatTensor]] = None
 
 
 class LightGlueImageProcessor(SuperGlueImageProcessor):
     def post_process_keypoint_matching(
         self,
         outputs: LightGlueKeypointMatchingOutput,
-        target_sizes: Union[TensorType, List[Tuple]],
+        target_sizes: Union[TensorType, list[tuple]],
         threshold: float = 0.0,
-    ) -> List[Dict[str, torch.Tensor]]:
+    ) -> list[dict[str, torch.Tensor]]:
         return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
 
     def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
@@ -263,7 +263,7 @@ def __init__(self, config: LightGlueConfig):
 
     def forward(
         self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
-    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
         projected_keypoints = self.projector(keypoints)
         embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
         cosines = torch.cos(embeddings)
@@ -277,12 +277,12 @@ class LightGlueAttention(LlamaAttention):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         **kwargs: Unpack[FlashAttentionKwargs],
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
         input_shape = hidden_states.shape[:-1]
         hidden_shape = (*input_shape, -1, self.head_dim)
 
@@ -348,7 +348,7 @@ def forward(
         attention_mask: torch.Tensor,
         output_hidden_states: Optional[bool] = False,
         output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]], Optional[Tuple[torch.Tensor]]]:
+    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
 
@@ -509,7 +509,7 @@ def _init_weights(self, module: nn.Module) -> None:
             module.weight.data.fill_(1.0)
 
 
-def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]:
+def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
    """obtain matches from a score matrix [Bx M+1 x N+1]"""
    batch_size, _, _ = scores.shape
    # For each keypoint, get the best match
@@ -622,7 +622,7 @@ def _get_confidence_threshold(self, layer_index: int) -> float:
 
     def _keypoint_processing(
         self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
         descriptors = descriptors.detach().contiguous()
         projected_descriptors = self.input_projection(descriptors)
         keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
@@ -733,7 +733,7 @@ def _do_final_keypoint_pruning(
         matches: torch.Tensor,
         matching_scores: torch.Tensor,
         num_keypoints: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
         # have tensors from
         batch_size, _ = indices.shape
@@ -773,7 +773,7 @@ def _match_image_pair(
         mask: torch.Tensor = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple, Tuple]:
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
         all_hidden_states = () if output_hidden_states else None
         all_attentions = () if output_attentions else None
 
@@ -949,7 +949,7 @@ def forward(
         labels: Optional[torch.LongTensor] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
-    ) -> Union[Tuple, LightGlueKeypointMatchingOutput]:
+    ) -> Union[tuple, LightGlueKeypointMatchingOutput]:
         loss = None
         if labels is not None:
             raise ValueError("LightGlue is not trainable, no labels should be provided.")
