We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent c6394f8 commit 9bf8050Copy full SHA for 9bf8050
repe/rep_reading_pipeline.py
@@ -22,8 +22,9 @@ def _get_hidden_states(
22
hidden_states_layers = {}
23
for layer in hidden_layers:
24
hidden_states = outputs['hidden_states'][layer]
25
- hidden_states = hidden_states[:, rep_token, :]
26
- # hidden_states_layers[layer] = hidden_states.cpu().to(dtype=torch.float32).detach().numpy()
+ hidden_states = hidden_states[:, rep_token, :].detach()
+ if hidden_states.dtype == torch.bfloat16:
27
+ hidden_states = hidden_states.float()
28
hidden_states_layers[layer] = hidden_states.detach()
29
30
return hidden_states_layers
0 commit comments