|
15 | 15 |
|
16 | 16 | from nestedtensor import _C |
17 | 17 |
|
| 18 | +from numbers import Number |
| 19 | + |
18 | 20 | orig_squeeze = torch.squeeze |
19 | 21 |
|
20 | 22 |
|
@@ -161,7 +163,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin |
161 | 163 | input_buffer, running_mean, running_var, weight, bias, training, momentum, eps) |
162 | 164 | return nested.NestedTensor(_C._BufferNestedTensor(result.flatten(), input.nested_size())) |
163 | 165 |
|
164 | | - def t_batch_norm(input: torch.Tensor, running_mean: torch.Tensor, running_var: torch.Tensor, weight, bias, training, momentum, eps): |
| 166 | + def t_batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps): |
165 | 167 | squeeze_after = False |
166 | 168 | # TODO: Need support for BatchNorm1d and BatchNorm2d as well |
167 | 169 | if input.dim() == 3: |
@@ -233,7 +235,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest', |
233 | 235 | if utils.find_nested_tensor_dispatch_key(input) is None: |
234 | 236 | return orig_interpolate(input, size, scale_factor, mode, align_corners) |
235 | 237 |
|
236 | | - def _interpolate(input: torch.Tensor, size: int, scale_factor: float, mode: str, align_corners: bool) -> torch.Tensor: |
| 238 | + def _interpolate(input, size, scale_factor, mode, align_corners): |
237 | 239 | # TODO: Document this |
238 | 240 | squeeze_after = False |
239 | 241 | if input.dim() == 3: |
@@ -269,10 +271,9 @@ def mm(*args, **kwargs): |
269 | 271 | self.nested_size(), self.dim() - 1, result.size(-1)) |
270 | 272 | buffer_ = result.flatten() |
271 | 273 | return nested.NestedTensor( |
272 | | - _C._BufferNestedTensor(buffer_, |
273 | | - result_nested_size)) |
| 274 | + _C._BufferNestedTensor(buffer_, result_nested_size)) |
274 | 275 |
|
275 | | - tf = utils.tensorwise()(getattr(torch.Tensor, 'mm')) |
| 276 | + tf = utils.tensorwise()(torch.Tensor.mm) |
276 | 277 | return tf(*args, **kwargs) |
277 | 278 |
|
278 | 279 |
|
|
0 commit comments