diff --git a/modules/3-loss-functions-for-classification.md b/modules/3-loss-functions-for-classification.md index 5397a96..cb5f76b 100644 --- a/modules/3-loss-functions-for-classification.md +++ b/modules/3-loss-functions-for-classification.md @@ -58,7 +58,7 @@ loss1 = nn.NLLLoss() loss2 = nn.CrossEntropyLoss() C = 8 input = torch.randn(3,C,4,5) -target = torch.empty(3,4,5 dtype=torch.long).random_(0,C) +target = torch.empty(3,4,5, dtype=torch.long).random_(0,C) assert loss1(m(input),target) == loss2(input,target) ```