
Commit 18b2af5

20200319 Merge fbcode updates (pytorch#82)
1 parent 36539ce commit 18b2af5

21 files changed with 92 additions and 97 deletions

benchmarks/basic.py

Lines changed: 2 additions & 2 deletions
@@ -7,13 +7,13 @@
 def gen_list_nested_tensor_construction():
     tensors = [torch.rand(random.randint(500, 1500), 25600) for _ in range(20)]
     def _algorithm():
-        nt = torch._ListNestedTensor(tensors)
+        torch._ListNestedTensor(tensors)
     return _algorithm

 def gen_list_nested_tensor_unbind():
     nested_tensor = torch._ListNestedTensor([torch.rand(random.randint(500, 1500), 25600) for _ in range(20)])
     def _algorithm():
-        ts = nested_tensor.unbind()
+        nested_tensor.unbind()
     return _algorithm

 if __name__ == "__main__":

benchmarks/jit_apply.py

Lines changed: 2 additions & 2 deletions
@@ -42,11 +42,11 @@ def gen_jit():

 def gen_my_fun(scalar, tensor):
     @torch.jit.ignore
-    def get_scalar() -> float:
+    def get_scalar():
         return scalar

     @torch.jit.ignore
-    def get_tensor() -> torch.Tensor:
+    def get_tensor():
         return tensor

     @torch.jit.script

benchmarks/nearest_neighbors.py

Lines changed: 0 additions & 1 deletion
@@ -104,7 +104,6 @@ def print_results(results, keys, sub_clusters, print_details=False):

 def benchmark_fn(fn, run_time = 15.0):
     times = []
-    num_runs = 0
     fn()
     t = 0.0
     while (t < run_time):

benchmarks/utils.py

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@ def gen_tensor():

 def benchmark_fn(fn, run_time = 5.0, use_cprofile=False, warmup=1.0):
     times = []
-    num_runs = 0
     t = 0.0
     pr = cProfile.Profile()
     cuda_avail = torch.cuda.is_available()

nestedtensor/csrc/py_init.cpp

Lines changed: 2 additions & 0 deletions
@@ -46,9 +46,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::overload_cast<c10::optional<int64_t>>(
               &THPNestedTensor::nested_stride))
       .def("__getitem__", py::overload_cast<int64_t>(&THPNestedTensor::getitem))
+#if (PYBIND11_VERSION_MAJOR == 2 && PYBIND11_VERSION_MINOR >= 4)
       .def(
           "__getitem__",
           py::overload_cast<py::slice>(&THPNestedTensor::getitem))
+#endif
       .def(
           "unbind",
           torch::wrap_pybind_function([](THPNestedTensor self, int64_t dim) {

nestedtensor/csrc/python_nested_tensor.h

Lines changed: 2 additions & 0 deletions
@@ -51,10 +51,12 @@ struct THPNestedTensor {
   pybind11::object getitem(int64_t key) {
     return unbind(0)[key];
   }
+#if (PYBIND11_VERSION_MAJOR == 2 && PYBIND11_VERSION_MINOR >= 4)
   pybind11::object getitem(py::slice key) {
     py::list unbound = py::cast(unbind(0));
     return unbound[key];
   }
+#endif
   std::vector<pybind11::object> unbind(int64_t dim);
   THPIValueNode nested_size();
   THPIValueNode nested_stride();
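
Note on the two guarded hunks above: the new getitem(py::slice) overload and its binding are compiled only against pybind11 2.4 or newer. Semantically it just unbinds along dimension 0 and applies the Python slice to the resulting list. A rough Python rendering of that behavior (illustrative sketch only, not code from this commit; nt stands for any object exposing the bound unbind method):

# Illustrative only: mirrors
#   py::list unbound = py::cast(unbind(0)); return unbound[key];
def getitem_slice(nt, key):
    return list(nt.unbind(0))[key]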

nestedtensor/nested/functions.py

Lines changed: 6 additions & 5 deletions
@@ -15,6 +15,8 @@

 from nestedtensor import _C

+from numbers import Number
+
 orig_squeeze = torch.squeeze

@@ -161,7 +163,7 @@ def batch_norm(input, running_mean, running_var, weight=None, bias=None, trainin
             input_buffer, running_mean, running_var, weight, bias, training, momentum, eps)
         return nested.NestedTensor(_C._BufferNestedTensor(result.flatten(), input.nested_size()))

-    def t_batch_norm(input: torch.Tensor, running_mean: torch.Tensor, running_var: torch.Tensor, weight, bias, training, momentum, eps):
+    def t_batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps):
        squeeze_after = False
        # TODO: Need support for BatchNorm1d and BatchNorm2d as well
        if input.dim() == 3:
@@ -233,7 +235,7 @@ def interpolate(input, size=None, scale_factor=None, mode='nearest',
     if utils.find_nested_tensor_dispatch_key(input) is None:
         return orig_interpolate(input, size, scale_factor, mode, align_corners)

-    def _interpolate(input: torch.Tensor, size: int, scale_factor: float, mode: str, align_corners: bool) -> torch.Tensor:
+    def _interpolate(input, size, scale_factor, mode, align_corners):
        # TODO: Document this
        squeeze_after = False
        if input.dim() == 3:
@@ -269,10 +271,9 @@ def mm(*args, **kwargs):
             self.nested_size(), self.dim() - 1, result.size(-1))
         buffer_ = result.flatten()
         return nested.NestedTensor(
-            _C._BufferNestedTensor(buffer_,
-                                   result_nested_size))
+            _C._BufferNestedTensor(buffer_, result_nested_size))

-    tf = utils.tensorwise()(getattr(torch.Tensor, 'mm'))
+    tf = utils.tensorwise()(torch.Tensor.mm)
     return tf(*args, **kwargs)

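The last hunk above is a pure simplification: getattr with a literal attribute name resolves to exactly the same object as direct attribute access, so torch.Tensor.mm can replace getattr(torch.Tensor, 'mm') with no behavior change. A quick check (assuming torch is installed):

import torch

# getattr with a constant name and direct attribute access return the same object.
assert getattr(torch.Tensor, 'mm') is torch.Tensor.mm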

nestedtensor/nested/masking.py

Lines changed: 4 additions & 2 deletions
@@ -78,7 +78,9 @@ def _merge(tensors, nested_dim):
     return creation.nested_tensor(inner_tensors)

 # Get max size per each dimension from all the passed tensors.
-def get_max_size(obj, res=[1]):
+def get_max_size(obj, res=None):
+    if res is None:
+        res = [1]
     if isinstance(obj, list) or isinstance(obj, tuple):
         for o in obj:
             res = get_max_size(o, res)
@@ -138,7 +140,7 @@ def pad_nt(nt, shape):
 # Return a tuple of a tensor and a mask that represent the given tensor list
 # Returned tensor is always the same no matter what mask_dim was passed.
 # If mask_dim was not passed, a mask with the smallest dimensionality would be returned.
-# if passed mask_dim is lower than the minimal dimensionality of the mask that can represent 
+# if passed mask_dim is lower than the minimal dimensionality of the mask that can represent
 # the data tensor, an error is thrown.
 def to_tensor_mask(nt, mask_dim):
     if mask_dim is not None and mask_dim > nt.dim():
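The get_max_size hunk above replaces a mutable default argument with the None-sentinel idiom. A standalone sketch of the pitfall being avoided (toy code, not from this repo):

# A mutable default is built once at definition time and shared across calls.
def buggy(obj, res=[1]):
    res.append(len(obj))
    return res

print(buggy([10, 20]))  # [1, 2]
print(buggy([10, 20]))  # [1, 2, 2]  <- state leaked from the previous call

# The sentinel form creates a fresh list on every call, as get_max_size now does.
def fixed(obj, res=None):
    if res is None:
        res = [1]
    res.append(len(obj))
    return res

print(fixed([10, 20]))  # [1, 2]
print(fixed([10, 20]))  # [1, 2]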

nestedtensor/nested/monkey_patch.py

Lines changed: 3 additions & 7 deletions
@@ -5,9 +5,6 @@ def monkey_patch(NestedTensor):
     of the torch.Tensor or torch module corresponding implementations.
     """

-    import os
-    DEBUG = int(os.getenv("DEBUG", 0))
-
     from nestedtensor.nested import codegen
     from nestedtensor.nested import functions
     import torch
@@ -248,7 +245,6 @@ def new_fn(self):

     # module.NestedTensor = NestedTensor

-    setattr(NestedTensor, '_NestedTensor__function_dispatch', function_dispatch)
-    setattr(NestedTensor, '_NestedTensor__jit_function_dispatch',
-            jit_function_dispatch)
-    setattr(NestedTensor, '_NestedTensor__C_functions', C_functions)
+    NestedTensor._NestedTensor__function_dispatch = function_dispatch
+    NestedTensor._NestedTensor__jit_function_dispatch = jit_function_dispatch
+    NestedTensor._NestedTensor__C_functions = C_functions
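The setattr-to-assignment change above relies on Python's name mangling: inside the class body, self.__function_dispatch is rewritten to self._NestedTensor__function_dispatch, so assigning the mangled name directly is equivalent to the old setattr calls and reads more plainly. A toy illustration (not the repo's NestedTensor):

class Toy:
    def lookup(self, name):
        # name-mangled at compile time to self._Toy__function_dispatch
        return self.__function_dispatch[name]

dispatch_table = {"add": "add-impl"}

setattr(Toy, '_Toy__function_dispatch', dispatch_table)  # old spelling
Toy._Toy__function_dispatch = dispatch_table             # equivalent direct form

print(Toy().lookup("add"))  # prints: add-impl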

nestedtensor/nested/utils.py

Lines changed: 2 additions & 2 deletions
@@ -218,7 +218,7 @@ def _get_tensortype_args(args, kwargs):
     for a in args:
         if is_nested_tensor(a) or torch.is_tensor(a):
             tt_args.append(a)
-    for k, v in kwargs.items():
+    for _, v in kwargs.items():
         if is_nested_tensor(v) or torch.is_tensor(v):
             tt_args.append(v)
     return tt_args
@@ -229,7 +229,7 @@ def _get_nestedtensor_args(args, kwargs):
     for a in args:
         if is_nested_tensor(a):
             nt_args.append(a)
-    for k, v in kwargs.items():
+    for _, v in kwargs.items():
         if is_nested_tensor(v):
             nt_args.append(v)
     return nt_args

0 commit comments
