Skip to content

Commit 9fbe3b1

Browse files
committed
computing accuracy for node classification
1 parent bc8c5ed commit 9fbe3b1

File tree

1 file changed

+3
-1
lines changed
  • flgo/benchmark/toolkits/graph/node_classification

1 file changed

+3
-1
lines changed

flgo/benchmark/toolkits/graph/node_classification/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,19 @@ def test(self, model, dataset, batch_size=64, num_workers=0, pin_memory=False):
192192
dataset.change_mask_for_test()
193193
loader = self.DataLoader([dataset.data], batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
194194
total_loss = 0
195+
total_correct = 0
195196
total_num_samples = 0
196197
for batch in loader:
197198
tdata = self.data_to_device(batch)
198199
outputs = model(tdata)
199200
loss = self.criterion(outputs[tdata.test_mask], tdata.y[tdata.test_mask])
200201
num_samples = len(tdata.x)
201202
total_loss += num_samples * loss
203+
total_correct += outputs[tdata.test_mask].max(1)[1].eq(tdata.y[tdata.test_mask]).sum().item()
202204
total_num_samples += num_samples
203205
total_loss = total_loss.item()
204206
dataset.restore_mask()
205-
return {'loss': total_loss / total_num_samples}
207+
return {'loss': total_loss / total_num_samples, 'accuracy':1.0*total_correct/total_num_samples}
206208

207209
def data_to_device(self, data):
208210
return data.to(self.device)

0 commit comments

Comments
 (0)