File tree Expand file tree Collapse file tree 1 file changed +9
-4
lines changed
Expand file tree Collapse file tree 1 file changed +9
-4
lines changed Original file line number Diff line number Diff line change @@ -100,11 +100,14 @@ def model_step(
100100 x , y = batch
101101 logits = self .forward (x )
102102
103- if self .num_classes == 2 :
103+ if self .num_classes == 2 and logits .shape [- 1 ] == 1 :
104+ # Binary classification with single output neuron
104105 y = y .view (- 1 , 1 ).float ()
105106 loss = self .criterion (logits , y )
106- preds = (logits > 0.5 ).float ()
107+ preds = (torch .sigmoid (logits ) > 0.5 ).float ().squeeze ()
108+ y = y .squeeze ()
107109 else :
110+ # Multi-class or binary with 2 output neurons
108111 y = y .long ()
109112 loss = self .criterion (logits , y )
110113 preds = torch .argmax (logits , dim = 1 )
@@ -206,9 +209,11 @@ def predict_step(
206209 x , y = batch
207210 logits = self .forward (x )
208211
209- if self .num_classes == 2 :
210- preds = (logits > 0.5 ).float ()
212+ if self .num_classes == 2 and logits .shape [- 1 ] == 1 :
213+ # Binary classification with single output neuron
214+ preds = (torch .sigmoid (logits ) > 0.5 ).float ().squeeze ()
211215 else :
216+ # Multi-class or binary with 2 output neurons
212217 preds = torch .argmax (logits , dim = 1 )
213218
214219 return preds
You can’t perform that action at this time.
0 commit comments