Skip to content

Commit 50d16db

Browse files
authored
Merge pull request #116 from ACMCMC/patch-1
Update detoxify.py
2 parents 6a9b738 + 1cda8b1 commit 50d16db

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

detoxify/detoxify.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -115,11 +115,12 @@ def predict(self, text):
115115
self.model.eval()
116116
inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(self.model.device)
117117
out = self.model(**inputs)[0]
118-
scores = torch.sigmoid(out).cpu().detach().numpy()
118+
scores = torch.sigmoid(out).cpu()
119119
results = {}
120120
for i, cla in enumerate(self.class_names):
121121
results[cla] = (
122-
scores[0][i] if isinstance(text, str) else [scores[ex_i][i].tolist() for ex_i in range(len(scores))]
122+
# If the input is a single text, squeezing will remove the dimensionality from the tensor - so `.tolist()` will return a number instead. Otherwise, we'll get the list of scores of that class.
123+
scores[:,i].squeeze().tolist()
123124
)
124125
return results
125126

0 commit comments

Comments
 (0)