Skip to content

Commit

Permalink
Merge pull request #116 from ACMCMC/patch-1
Browse files Browse the repository at this point in the history
Update detoxify.py
  • Loading branch information
laurahanu authored Jan 2, 2025
2 parents 6a9b738 + 1cda8b1 commit 50d16db
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions detoxify/detoxify.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,12 @@ def predict(self, text):
self.model.eval()
inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
out = self.model(**inputs)[0]
scores = torch.sigmoid(out).cpu().detach().numpy()
scores = torch.sigmoid(out).cpu()
results = {}
for i, cla in enumerate(self.class_names):
results[cla] = (
scores[0][i] if isinstance(text, str) else [scores[ex_i][i].tolist() for ex_i in range(len(scores))]
# If the input is a single text, squeezing will remove the dimensionality from the tensor - so `.tolist()` will return a number instead. Otherwise, we'll get the list of scores of that class.
scores[:,i].squeeze().tolist()
)
return results

Expand Down

0 comments on commit 50d16db

Please sign in to comment.