Skip to content

Commit 9bf8050

Browse files
authored
bfloat16 support (#51)
1 parent c6394f8 commit 9bf8050

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

repe/rep_reading_pipeline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ def _get_hidden_states(
2222
hidden_states_layers = {}
2323
for layer in hidden_layers:
2424
hidden_states = outputs['hidden_states'][layer]
25-
hidden_states = hidden_states[:, rep_token, :]
26-
# hidden_states_layers[layer] = hidden_states.cpu().to(dtype=torch.float32).detach().numpy()
25+
hidden_states = hidden_states[:, rep_token, :].detach()
26+
if hidden_states.dtype == torch.bfloat16:
27+
hidden_states = hidden_states.float()
2728
hidden_states_layers[layer] = hidden_states.detach()
2829

2930
return hidden_states_layers

0 commit comments

Comments
 (0)