@@ -13,62 +13,14 @@
 # limitations under the License.

 import gc
-import random
 import warnings
+from collections.abc import Mapping
 from contextlib import contextmanager
 from typing import Optional, Union

 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils.rnn import pad_sequence
-from transformers import TopKLogitsWarper, TopPLogitsWarper, is_torch_npu_available, is_torch_xpu_available
-
-
-try:
-    from collections.abc import Mapping
-except ImportError:
-    from collections.abc import Mapping
-
-
-WANDB_PADDING = -1
-
-
-def top_k_top_p_filtering(
-    logits: torch.FloatTensor,
-    top_k: int = 0,
-    top_p: float = 1.0,
-    filter_value: float = -float("Inf"),
-    min_tokens_to_keep: int = 1,
-) -> torch.FloatTensor:
-    """
-    Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
-
-    Args:
-        logits: logits distribution shape (batch size, vocabulary size)
-        top_k (`int`, *optional*, defaults to 0):
-            If > 0, only keep the top k tokens with highest probability (top-k filtering)
-        top_p (`float`, *optional*, defaults to 1.0):
-            If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
-            filtering is described in Holtzman et al. (https://huggingface.co/papers/1904.09751)
-        min_tokens_to_keep (`int`, *optional*, defaults to 1):
-            Minimum number of tokens we keep per batch example in the output.
-
-    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
-    """
-
-    if top_k > 0:
-        logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    if 0 <= top_p <= 1.0:
-        logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
-            None, logits
-        )
-
-    return logits
+from transformers import is_torch_npu_available, is_torch_xpu_available


 def flatten_dict(nested: dict, sep: str = "/") -> dict:
@@ -88,52 +40,6 @@ def recurse(nest: dict, prefix: str, into: dict) -> None:
     return flat


-def convert_to_scalar(stats: dict) -> dict:
-    """
-    Converts the stats from a flattened dict to single scalar dicts
-    """
-    tensorboard_stats = {}
-    for k, v in stats.items():
-        # for tensorboard compatibility - arrays and tensors are ignored with tensorboard
-        # therefore we convert single element tensors to scalars
-        if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
-            len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
-        ):
-            v = v.item()
-        tensorboard_stats[k] = v
-    return tensorboard_stats
-
-
-def stack_dicts(stats_dicts: list[dict]) -> dict:
-    """Stack the values of a dict."""
-    results = dict()
-    for k in stats_dicts[0]:
-        stats_list = [torch.flatten(d[k]) for d in stats_dicts]
-        results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
-    return results
-
-
-def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
-    """
-    See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
-    """
-    logp = F.log_softmax(logits, dim=2)
-
-    if not gather:
-        return logp
-    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
-    return logpy
-
-
-def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
-    """Whiten values."""
-    mean, var = torch.mean(values), torch.var(values)
-    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
-    if not shift_mean:
-        whitened += mean
-    return whitened
-
-
 def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
     """Compute mean of tensor with masked values."""
     if axis is not None:
@@ -170,73 +76,6 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T
     return whitened


-def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
-    """
-    Tensor extension to torch.clamp
-    https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
-    """
-    clipped = torch.max(torch.min(x, tensor_max), tensor_min)
-    return clipped
-
-
-def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
-    """Calculate entropy from logits."""
-    pd = torch.nn.functional.softmax(logits, dim=-1)
-    entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
-    return entropy
-
-
-def stats_to_np(stats_dict: dict) -> dict:
-    """Cast all torch.tensors in dict to numpy arrays."""
-    new_dict = dict()
-    for k, v in stats_dict.items():
-        if isinstance(v, torch.Tensor):
-            new_dict[k] = v.detach().cpu()
-            if new_dict[k].dtype == torch.bfloat16:
-                new_dict[k] = new_dict[k].float()
-            new_dict[k] = new_dict[k].numpy()
-        else:
-            new_dict[k] = v
-        if np.isscalar(new_dict[k]):
-            new_dict[k] = float(new_dict[k])
-    return new_dict
-
-
-def respond_to_batch(
-    model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
-) -> torch.LongTensor:
-    """Sample text from language model."""
-    input_ids = queries
-    for _i in range(txt_len):
-        # Get Logits
-        outputs = model(input_ids)
-        next_token_logits = outputs[0][:, -1, :]
-        next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-        # Sample
-        probs = F.softmax(next_token_logits, dim=-1)
-        next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
-        input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
-    return input_ids[:, -txt_len:]
-
-
-def set_seed(seed: int) -> None:
-    """
-    Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
-
-    Args:
-        seed (`int`): The seed to set.
-    """
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    if is_torch_xpu_available():
-        torch.xpu.manual_seed_all(seed)
-    elif is_torch_npu_available():
-        torch.npu.manual_seed_all(seed)
-    else:
-        torch.cuda.manual_seed_all(seed)
-
-
 class LengthSampler:
     """
     Samples a length
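
For reference, a minimal sketch (not part of this diff) of how the same top-k / top-p logit filtering can be reproduced with the `transformers` logits warpers that the removed `top_k_top_p_filtering` helper wrapped. The shapes and hyperparameter values below are illustrative, not taken from the repository.

```python
# Illustrative only: applies the same transformers logits warpers that the
# removed helper delegated to, then samples from the filtered distribution.
import torch
from transformers import TopKLogitsWarper, TopPLogitsWarper

batch_size, vocab_size = 2, 32000  # hypothetical shapes
logits = torch.randn(batch_size, vocab_size)

# Keep the 50 highest-probability tokens, then apply nucleus (top-p) filtering.
logits = TopKLogitsWarper(top_k=50, filter_value=-float("Inf"))(None, logits)
logits = TopPLogitsWarper(top_p=0.95, filter_value=-float("Inf"))(None, logits)

# Filtered positions get -inf logits, i.e. zero probability after softmax.
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
```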