Skip to content

Commit 8c49ea3

Browse files
authored
🏚 Remove unused components (huggingface#2480)
1 parent 88ad1a0 commit 8c49ea3

File tree

11 files changed

+24
-280
lines changed

11 files changed

+24
-280
lines changed

examples/research_projects/stack_llama/scripts/rl_training.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from datasets import load_dataset
2121
from peft import LoraConfig
2222
from tqdm import tqdm
23-
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline
23+
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed
2424

25-
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, set_seed
25+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
2626
from trl.core import LengthSampler
2727

2828

examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@
2525
HfArgumentParser,
2626
RobertaForSequenceClassification,
2727
RobertaTokenizer,
28+
set_seed,
2829
)
2930

30-
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model, set_seed
31+
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
3132
from trl.core import LengthSampler
3233

3334

tests/test_core.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import torch
1818

19-
from trl.core import masked_mean, masked_var, masked_whiten, whiten
19+
from trl.core import masked_mean, masked_var, masked_whiten
2020

2121

2222
class CoreTester(unittest.TestCase):
@@ -36,6 +36,10 @@ def test_masked_var(self):
3636
self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask))
3737

3838
def test_masked_whiten(self):
39+
def whiten(values: torch.Tensor) -> torch.Tensor:
40+
mean, var = torch.mean(values), torch.var(values)
41+
return (values - mean) * torch.rsqrt(var + 1e-8)
42+
3943
whiten_unmasked = whiten(self.test_input_unmasked)
4044
whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3]
4145
diffs = (whiten_unmasked - whiten_masked).sum()

trl/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121

2222
_import_structure = {
2323
"scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
24-
"core": ["set_seed"],
2524
"data_utils": [
2625
"apply_chat_template",
2726
"extract_prompt",
@@ -115,7 +114,6 @@
115114
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
116115

117116
if TYPE_CHECKING:
118-
from .core import set_seed
119117
from .data_utils import (
120118
apply_chat_template,
121119
extract_prompt,

trl/core.py

Lines changed: 2 additions & 163 deletions
Original file line numberDiff line numberDiff line change
@@ -13,62 +13,14 @@
1313
# limitations under the License.
1414

1515
import gc
16-
import random
1716
import warnings
17+
from collections.abc import Mapping
1818
from contextlib import contextmanager
1919
from typing import Optional, Union
2020

2121
import numpy as np
2222
import torch
23-
import torch.nn as nn
24-
import torch.nn.functional as F
25-
from torch.nn.utils.rnn import pad_sequence
26-
from transformers import TopKLogitsWarper, TopPLogitsWarper, is_torch_npu_available, is_torch_xpu_available
27-
28-
29-
try:
30-
from collections.abc import Mapping
31-
except ImportError:
32-
from collections.abc import Mapping
33-
34-
35-
WANDB_PADDING = -1
36-
37-
38-
def top_k_top_p_filtering(
39-
logits: torch.FloatTensor,
40-
top_k: int = 0,
41-
top_p: float = 1.0,
42-
filter_value: float = -float("Inf"),
43-
min_tokens_to_keep: int = 1,
44-
) -> torch.FloatTensor:
45-
"""
46-
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
47-
48-
Args:
49-
logits: logits distribution shape (batch size, vocabulary size)
50-
top_k (`int`, *optional*, defaults to 0):
51-
If > 0, only keep the top k tokens with highest probability (top-k filtering)
52-
top_p (`float`, *optional*, defaults to 1.0):
53-
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
54-
filtering is described in Holtzman et al. (https://huggingface.co/papers/1904.09751)
55-
min_tokens_to_keep (`int`, *optional*, defaults to 1):
56-
Minimum number of tokens we keep per batch example in the output.
57-
58-
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
59-
"""
60-
61-
if top_k > 0:
62-
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
63-
None, logits
64-
)
65-
66-
if 0 <= top_p <= 1.0:
67-
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
68-
None, logits
69-
)
70-
71-
return logits
23+
from transformers import is_torch_npu_available, is_torch_xpu_available
7224

7325

7426
def flatten_dict(nested: dict, sep: str = "/") -> dict:
@@ -88,52 +40,6 @@ def recurse(nest: dict, prefix: str, into: dict) -> None:
8840
return flat
8941

9042

91-
def convert_to_scalar(stats: dict) -> dict:
92-
"""
93-
Converts the stats from a flattened dict to single scalar dicts
94-
"""
95-
tensorboard_stats = {}
96-
for k, v in stats.items():
97-
# for tensorboard compatibility - arrays and tensors are ignored with tensorboard
98-
# therefore we convert single element tensors to scalars
99-
if (isinstance(v, torch.Tensor) or isinstance(v, np.ndarray)) and (
100-
len(v.shape) == 0 or (len(v.shape) == 1 and v.shape[0] == 1)
101-
):
102-
v = v.item()
103-
tensorboard_stats[k] = v
104-
return tensorboard_stats
105-
106-
107-
def stack_dicts(stats_dicts: list[dict]) -> dict:
108-
"""Stack the values of a dict."""
109-
results = dict()
110-
for k in stats_dicts[0]:
111-
stats_list = [torch.flatten(d[k]) for d in stats_dicts]
112-
results[k] = pad_sequence(stats_list, batch_first=True, padding_value=WANDB_PADDING)
113-
return results
114-
115-
116-
def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor, gather: bool = True) -> torch.Tensor:
117-
"""
118-
See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
119-
"""
120-
logp = F.log_softmax(logits, dim=2)
121-
122-
if not gather:
123-
return logp
124-
logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
125-
return logpy
126-
127-
128-
def whiten(values: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
129-
"""Whiten values."""
130-
mean, var = torch.mean(values), torch.var(values)
131-
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
132-
if not shift_mean:
133-
whitened += mean
134-
return whitened
135-
136-
13743
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor:
13844
"""Compute mean of tensor with masked values."""
13945
if axis is not None:
@@ -170,73 +76,6 @@ def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = T
17076
return whitened
17177

17278

173-
def clip_by_value(x: torch.Tensor, tensor_min: float, tensor_max: float) -> torch.Tensor:
174-
"""
175-
Tensor extension to torch.clamp
176-
https://github.com/pytorch/pytorch/issues/2793#issuecomment-428784713
177-
"""
178-
clipped = torch.max(torch.min(x, tensor_max), tensor_min)
179-
return clipped
180-
181-
182-
def entropy_from_logits(logits: torch.Tensor) -> torch.Tensor:
183-
"""Calculate entropy from logits."""
184-
pd = torch.nn.functional.softmax(logits, dim=-1)
185-
entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
186-
return entropy
187-
188-
189-
def stats_to_np(stats_dict: dict) -> dict:
190-
"""Cast all torch.tensors in dict to numpy arrays."""
191-
new_dict = dict()
192-
for k, v in stats_dict.items():
193-
if isinstance(v, torch.Tensor):
194-
new_dict[k] = v.detach().cpu()
195-
if new_dict[k].dtype == torch.bfloat16:
196-
new_dict[k] = new_dict[k].float()
197-
new_dict[k] = new_dict[k].numpy()
198-
else:
199-
new_dict[k] = v
200-
if np.isscalar(new_dict[k]):
201-
new_dict[k] = float(new_dict[k])
202-
return new_dict
203-
204-
205-
def respond_to_batch(
206-
model: nn.Module, queries: list[torch.LongTensor], txt_len: int = 20, top_k: int = 0, top_p: float = 1.0
207-
) -> torch.LongTensor:
208-
"""Sample text from language model."""
209-
input_ids = queries
210-
for _i in range(txt_len):
211-
# Get Logits
212-
outputs = model(input_ids)
213-
next_token_logits = outputs[0][:, -1, :]
214-
next_token_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
215-
# Sample
216-
probs = F.softmax(next_token_logits, dim=-1)
217-
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
218-
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
219-
return input_ids[:, -txt_len:]
220-
221-
222-
def set_seed(seed: int) -> None:
223-
"""
224-
Helper function for reproducible behavior to set the seed in `random`, `numpy`, and `torch`.
225-
226-
Args:
227-
seed (`int`): The seed to set.
228-
"""
229-
random.seed(seed)
230-
np.random.seed(seed)
231-
torch.manual_seed(seed)
232-
if is_torch_xpu_available():
233-
torch.xpu.manual_seed_all(seed)
234-
elif is_torch_npu_available():
235-
torch.npu.manual_seed_all(seed)
236-
else:
237-
torch.cuda.manual_seed_all(seed)
238-
239-
24079
class LengthSampler:
24180
"""
24281
Samples a length

trl/extras/best_of_n_sampler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
from typing import Any, Callable, Optional, Union
1616

1717
import torch
18-
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
18+
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed
1919

20-
from ..core import set_seed
2120
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
2221

2322

trl/trainer/__init__.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# There is a circular import in the PPOTrainer if we let isort sort these
1615
from typing import TYPE_CHECKING
1716

1817
from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
@@ -21,7 +20,6 @@
2120
_import_structure = {
2221
"alignprop_config": ["AlignPropConfig"],
2322
"alignprop_trainer": ["AlignPropTrainer"],
24-
"base": ["BaseTrainer"],
2523
"bco_config": ["BCOConfig"],
2624
"bco_trainer": ["BCOTrainer"],
2725
"callbacks": [
@@ -41,8 +39,8 @@
4139
"iterative_sft_trainer": ["IterativeSFTTrainer"],
4240
"judges": [
4341
"AllTrueJudge",
44-
"BaseJudge",
4542
"BaseBinaryJudge",
43+
"BaseJudge",
4644
"BasePairwiseJudge",
4745
"BaseRankJudge",
4846
"HfPairwiseJudge",
@@ -60,23 +58,21 @@
6058
"orpo_trainer": ["ORPOTrainer"],
6159
"ppo_config": ["PPOConfig"],
6260
"ppo_trainer": ["PPOTrainer"],
63-
"ppov2_config": ["PPOv2Config"],
64-
"ppov2_trainer": ["PPOv2Trainer"],
6561
"prm_config": ["PRMConfig"],
6662
"prm_trainer": ["PRMTrainer"],
6763
"reward_config": ["RewardConfig"],
68-
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
64+
"reward_trainer": ["RewardTrainer"],
6965
"rloo_config": ["RLOOConfig"],
7066
"rloo_trainer": ["RLOOTrainer"],
7167
"sft_config": ["SFTConfig"],
7268
"sft_trainer": ["SFTTrainer"],
7369
"utils": [
74-
"AdaptiveKLController",
7570
"ConstantLengthDataset",
7671
"DataCollatorForCompletionOnlyLM",
77-
"FixedKLController",
7872
"RunningMoments",
73+
"compute_accuracy",
7974
"disable_dropout_in_model",
75+
"empty_cache",
8076
"peft_module_casting_to_bf16",
8177
],
8278
"xpo_config": ["XPOConfig"],
@@ -93,7 +89,6 @@
9389
if TYPE_CHECKING:
9490
from .alignprop_config import AlignPropConfig
9591
from .alignprop_trainer import AlignPropTrainer
96-
from .base import BaseTrainer
9792
from .bco_config import BCOConfig
9893
from .bco_trainer import BCOTrainer
9994
from .callbacks import (
@@ -135,17 +130,16 @@
135130
from .prm_config import PRMConfig
136131
from .prm_trainer import PRMTrainer
137132
from .reward_config import RewardConfig
138-
from .reward_trainer import RewardTrainer, compute_accuracy
133+
from .reward_trainer import RewardTrainer
139134
from .rloo_config import RLOOConfig
140135
from .rloo_trainer import RLOOTrainer
141136
from .sft_config import SFTConfig
142137
from .sft_trainer import SFTTrainer
143138
from .utils import (
144-
AdaptiveKLController,
145139
ConstantLengthDataset,
146140
DataCollatorForCompletionOnlyLM,
147-
FixedKLController,
148141
RunningMoments,
142+
compute_accuracy,
149143
disable_dropout_in_model,
150144
empty_cache,
151145
peft_module_casting_to_bf16,

trl/trainer/alignprop_trainer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,11 @@
2222
from accelerate import Accelerator
2323
from accelerate.logging import get_logger
2424
from accelerate.utils import ProjectConfiguration, set_seed
25+
from huggingface_hub import PyTorchModelHubMixin
2526
from transformers import is_wandb_available
2627

2728
from ..models import DDPOStableDiffusionPipeline
28-
from . import AlignPropConfig, BaseTrainer
29+
from .alignprop_config import AlignPropConfig
2930
from .utils import generate_model_card, get_comet_experiment_url
3031

3132

@@ -35,7 +36,7 @@
3536
logger = get_logger(__name__)
3637

3738

38-
class AlignPropTrainer(BaseTrainer):
39+
class AlignPropTrainer(PyTorchModelHubMixin):
3940
"""
4041
The AlignPropTrainer uses Deep Diffusion Policy Optimization to optimise diffusion models.
4142
Note, this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/

0 commit comments

Comments
 (0)