Skip to content

Commit 8064a2f

Browse files
committed
Merge branch 'training'
2 parents ef56fd4 + 8194c7f commit 8064a2f

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

src/models/classification_module.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,14 @@ def model_step(
100100
x, y = batch
101101
logits = self.forward(x)
102102

103-
if self.num_classes == 2:
103+
if self.num_classes == 2 and logits.shape[-1] == 1:
104+
# Binary classification with single output neuron
104105
y = y.view(-1, 1).float()
105106
loss = self.criterion(logits, y)
106-
preds = (logits > 0.5).float()
107+
preds = (torch.sigmoid(logits) > 0.5).float().squeeze()
108+
y = y.squeeze()
107109
else:
110+
# Multi-class or binary with 2 output neurons
108111
y = y.long()
109112
loss = self.criterion(logits, y)
110113
preds = torch.argmax(logits, dim=1)
@@ -206,9 +209,11 @@ def predict_step(
206209
x, y = batch
207210
logits = self.forward(x)
208211

209-
if self.num_classes == 2:
210-
preds = (logits > 0.5).float()
212+
if self.num_classes == 2 and logits.shape[-1] == 1:
213+
# Binary classification with single output neuron
214+
preds = (torch.sigmoid(logits) > 0.5).float().squeeze()
211215
else:
216+
# Multi-class or binary with 2 output neurons
212217
preds = torch.argmax(logits, dim=1)
213218

214219
return preds

0 commit comments

Comments
 (0)