diff --git a/nets/stdcnet.py b/nets/stdcnet.py index 13ece7e..b797dcc 100644 --- a/nets/stdcnet.py +++ b/nets/stdcnet.py @@ -301,4 +301,5 @@ def forward_impl(self, x): x = torch.randn(1,3,224,224) y = model(x) torch.save(model.state_dict(), 'cat.pth') - print(y.size()) + print(f'Number of output nodes: {len(y)}') + print('\n'.join([f'Output node {i}\'s size: {node.size()}' for i, node in enumerate(y)]))