Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: DeepGraphLearning/torchdrug
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 73920708a99a8936282ca2ccbacba9d4187060ce
Choose a base ref
..
head repository: DeepGraphLearning/torchdrug
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: e73d2c93bc0fff205eeaae2d023fd530aed4981b
Choose a head ref
Showing with 0 additions and 4 deletions.
  1. +0 −4 torchdrug/metrics/metric.py
4 changes: 0 additions & 4 deletions torchdrug/metrics/metric.py
Original file line number Diff line number Diff line change
@@ -20,8 +20,6 @@ def area_under_roc(pred, target):
pred (Tensor): predictions of shape :math:`(n,)`
target (Tensor): binary targets of shape :math:`(n,)`
"""
if target.dtype != torch.long:
raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype)
order = pred.argsort(descending=True)
target = target[order]
hit = target.cumsum(0)
@@ -39,8 +37,6 @@ def area_under_prc(pred, target):
pred (Tensor): predictions of shape :math:`(n,)`
target (Tensor): binary targets of shape :math:`(n,)`
"""
if target.dtype != torch.long:
raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype)
order = pred.argsort(descending=True)
target = target[order]
precision = target.cumsum(0) / torch.arange(1, len(target) + 1, device=target.device)