12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
from dataclasses import dataclass
15
- from typing import Callable , Dict , List , Optional , Tuple , Union
15
+ from typing import Callable , Optional , Union
16
16
17
17
import numpy as np
18
18
import torch
@@ -196,17 +196,17 @@ class LightGlueKeypointMatchingOutput(ModelOutput):
196
196
keypoints : Optional [torch .FloatTensor ] = None
197
197
prune : Optional [torch .IntTensor ] = None
198
198
mask : Optional [torch .FloatTensor ] = None
199
- hidden_states : Optional [Tuple [torch .FloatTensor ]] = None
200
- attentions : Optional [Tuple [torch .FloatTensor ]] = None
199
+ hidden_states : Optional [tuple [torch .FloatTensor ]] = None
200
+ attentions : Optional [tuple [torch .FloatTensor ]] = None
201
201
202
202
203
203
class LightGlueImageProcessor (SuperGlueImageProcessor ):
204
204
def post_process_keypoint_matching (
205
205
self ,
206
206
outputs : LightGlueKeypointMatchingOutput ,
207
- target_sizes : Union [TensorType , List [ Tuple ]],
207
+ target_sizes : Union [TensorType , list [ tuple ]],
208
208
threshold : float = 0.0 ,
209
- ) -> List [ Dict [str , torch .Tensor ]]:
209
+ ) -> list [ dict [str , torch .Tensor ]]:
210
210
return super ().post_process_keypoint_matching (outputs , target_sizes , threshold )
211
211
212
212
def plot_keypoint_matching (self , images : ImageInput , keypoint_matching_output : LightGlueKeypointMatchingOutput ):
@@ -263,7 +263,7 @@ def __init__(self, config: LightGlueConfig):
263
263
264
264
def forward (
265
265
self , keypoints : torch .Tensor , output_hidden_states : Optional [bool ] = False
266
- ) -> Union [Tuple [torch .Tensor ], Tuple [torch .Tensor , torch .Tensor ]]:
266
+ ) -> Union [tuple [torch .Tensor ], tuple [torch .Tensor , torch .Tensor ]]:
267
267
projected_keypoints = self .projector (keypoints )
268
268
embeddings = projected_keypoints .repeat_interleave (2 , dim = - 1 )
269
269
cosines = torch .cos (embeddings )
@@ -277,12 +277,12 @@ class LightGlueAttention(LlamaAttention):
277
277
def forward (
278
278
self ,
279
279
hidden_states : torch .Tensor ,
280
- position_embeddings : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
280
+ position_embeddings : Optional [tuple [torch .Tensor , torch .Tensor ]] = None ,
281
281
attention_mask : Optional [torch .Tensor ] = None ,
282
282
encoder_hidden_states : Optional [torch .Tensor ] = None ,
283
283
encoder_attention_mask : Optional [torch .Tensor ] = None ,
284
284
** kwargs : Unpack [FlashAttentionKwargs ],
285
- ) -> Tuple [torch .Tensor , Optional [torch .Tensor ], Optional [Tuple [torch .Tensor ]]]:
285
+ ) -> tuple [torch .Tensor , Optional [torch .Tensor ], Optional [tuple [torch .Tensor ]]]:
286
286
input_shape = hidden_states .shape [:- 1 ]
287
287
hidden_shape = (* input_shape , - 1 , self .head_dim )
288
288
@@ -348,7 +348,7 @@ def forward(
348
348
attention_mask : torch .Tensor ,
349
349
output_hidden_states : Optional [bool ] = False ,
350
350
output_attentions : Optional [bool ] = False ,
351
- ) -> Tuple [torch .Tensor , Optional [Tuple [torch .Tensor ]], Optional [Tuple [torch .Tensor ]]]:
351
+ ) -> tuple [torch .Tensor , Optional [tuple [torch .Tensor ]], Optional [tuple [torch .Tensor ]]]:
352
352
all_hidden_states = () if output_hidden_states else None
353
353
all_attentions = () if output_attentions else None
354
354
@@ -509,7 +509,7 @@ def _init_weights(self, module: nn.Module) -> None:
509
509
module .weight .data .fill_ (1.0 )
510
510
511
511
512
- def get_matches_from_scores (scores : torch .Tensor , threshold : float ) -> Tuple [torch .Tensor , torch .Tensor ]:
512
+ def get_matches_from_scores (scores : torch .Tensor , threshold : float ) -> tuple [torch .Tensor , torch .Tensor ]:
513
513
"""obtain matches from a score matrix [Bx M+1 x N+1]"""
514
514
batch_size , _ , _ = scores .shape
515
515
# For each keypoint, get the best match
@@ -622,7 +622,7 @@ def _get_confidence_threshold(self, layer_index: int) -> float:
622
622
623
623
def _keypoint_processing (
624
624
self , descriptors : torch .Tensor , keypoints : torch .Tensor , output_hidden_states : Optional [bool ] = False
625
- ) -> Tuple [torch .Tensor , Tuple [torch .Tensor , torch .Tensor ]]:
625
+ ) -> tuple [torch .Tensor , tuple [torch .Tensor , torch .Tensor ]]:
626
626
descriptors = descriptors .detach ().contiguous ()
627
627
projected_descriptors = self .input_projection (descriptors )
628
628
keypoint_encoding_output = self .positional_encoder (keypoints , output_hidden_states = output_hidden_states )
@@ -733,7 +733,7 @@ def _do_final_keypoint_pruning(
733
733
matches : torch .Tensor ,
734
734
matching_scores : torch .Tensor ,
735
735
num_keypoints : torch .Tensor ,
736
- ) -> Tuple [torch .Tensor , torch .Tensor ]:
736
+ ) -> tuple [torch .Tensor , torch .Tensor ]:
737
737
# (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
738
738
# have tensors from
739
739
batch_size , _ = indices .shape
@@ -773,7 +773,7 @@ def _match_image_pair(
773
773
mask : torch .Tensor = None ,
774
774
output_attentions : Optional [bool ] = None ,
775
775
output_hidden_states : Optional [bool ] = None ,
776
- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , Tuple , Tuple ]:
776
+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , tuple , tuple ]:
777
777
all_hidden_states = () if output_hidden_states else None
778
778
all_attentions = () if output_attentions else None
779
779
@@ -949,7 +949,7 @@ def forward(
949
949
labels : Optional [torch .LongTensor ] = None ,
950
950
output_attentions : Optional [bool ] = None ,
951
951
output_hidden_states : Optional [bool ] = None ,
952
- ) -> Union [Tuple , LightGlueKeypointMatchingOutput ]:
952
+ ) -> Union [tuple , LightGlueKeypointMatchingOutput ]:
953
953
loss = None
954
954
if labels is not None :
955
955
raise ValueError ("LightGlue is not trainable, no labels should be provided." )
0 commit comments