Skip to content

Commit 4438c51

Browse files
committed
fix style
1 parent 358bd1e commit 4438c51

File tree

2 files changed

+14
-12
lines changed

2 files changed

+14
-12
lines changed

tests/test_utils.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,15 @@ def create_testing_model(architecture='fcn', num_classes=10):
5050
)
5151
elif architecture == 'rnn':
5252
return nn.Sequential(
53-
OrderedDict(
54-
[
55-
('first_layer', nn.LSTM(256, 128, 1, batch_first=True)),
56-
('extract', ExtractTensor()),
57-
('second_layer', nn.Linear(128, 64)),
58-
('third_layer', nn.Linear(64, num_classes)),
59-
],
60-
),
61-
)
53+
OrderedDict(
54+
[
55+
('first_layer', nn.LSTM(256, 128, 1, batch_first=True)),
56+
('extract', ExtractTensor()),
57+
('second_layer', nn.Linear(128, 64)),
58+
('third_layer', nn.Linear(64, num_classes)),
59+
],
60+
),
61+
)
6262
else:
6363
raise Exception(f'Unsupported architecture type: {architecture}')
6464

@@ -67,4 +67,4 @@ class ExtractTensor(nn.Module):
6767
def forward(self, x):
6868
tensor, _ = x
6969
x = x.to(torch.float32)
70-
return tensor[:, :]
70+
return tensor[:, :]

tests/tests.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def test_visualization_fcn():
5151
f"Wrong value type for key {key}",
5252
)
5353

54+
5455
def test_visualization_cnn():
5556
data = utils.create_testing_data(architecture='cnn')
5657
model = utils.create_testing_model(architecture='cnn')
@@ -114,8 +115,9 @@ def _test_bayes_prediction(mode: str, architecture='fcn'):
114115
utils.compare_values(dict, type(res), "Wrong result type")
115116
utils.compare_values(2, len(res), "Wrong dictionary length")
116117
utils.compare_values(set(["mean", "std"]), set(res.keys()), "Wrong dictionary keys")
117-
utils.compare_values(torch.Size([len(data), num_classes]), res["mean"].shape, "Wrong mean shape")
118-
utils.compare_values(torch.Size([len(data), num_classes]), res["std"].shape, "Wrong mean std")
118+
N = len(data)
119+
utils.compare_values(torch.Size([N, num_classes]), res["mean"].shape, "Wrong mean shape")
120+
utils.compare_values(torch.Size([N, num_classes]), res["std"].shape, "Wrong mean std")
119121

120122

121123
def test_basic_bayes_wrapper():

0 commit comments

Comments
 (0)