diff --git a/.gitignore b/.gitignore index 1921df7b5cc..1cddc1e1ac6 100644 --- a/.gitignore +++ b/.gitignore @@ -180,4 +180,4 @@ log .DS_Store Roms -scratch/*.py +scratch/* diff --git a/README.md b/README.md index f3b766facab..ab820c19697 100644 --- a/README.md +++ b/README.md @@ -826,85 +826,52 @@ If you're using TorchRL, please refer to this BibTeX entry to cite this work: ## Installation -Create a conda environment where the packages will be installed. +### Create a new virtual environment: +```bash +python -m venv torchrl +source torchrl/bin/activate # On Windows use: venv\Scripts\activate +``` + +Or create a conda environment where the packages will be installed. ``` -conda create --name torch_rl python=3.9 -conda activate torch_rl +conda create --name torchrl python=3.9 +conda activate torchrl ``` -**PyTorch** +### Install dependencies: -Depending on the use of functorch that you want to make, you may want to +#### PyTorch + +Depending on the use of torchrl that you want to make, you may want to install the latest (nightly) PyTorch release or the latest stable version of PyTorch. See [here](https://pytorch.org/get-started/locally/) for a detailed list of commands, including `pip3` or other special installation instructions. -**Torchrl** +TorchRL offers a few pre-defined dependencies such as `"torchrl[tests]"`, `"torchrl[atari]"` etc. + +#### Torchrl You can install the **latest stable release** by using ```bash pip3 install torchrl ``` -This should work on linux, Windows 10 and OsX (Intel or Silicon chips). -On certain Windows machines (Windows 11), one should install the library locally (see below). - -For AArch64 machines, the binaries are not yet stored on PyPI so you will need to download them directly from -the [release page](https://github.com/pytorch/rl/releases/) or install the library via -``` -pip3 install git+https://github.com/pytorch/rl@v0.8.0 -``` +This should work on linux (including AArch64 machines), Windows 10 and OsX (Metal chips only). 
+On certain Windows machines (Windows 11), one should build the library locally. +This can be done in two ways: -The **nightly build** can be installed via -```bash -pip3 install tensordict-nightly torchrl-nightly -``` -which we currently only ship for Linux machines. -Importantly, the nightly builds require the nightly builds of PyTorch too. - -To install extra dependencies, call -```bash -pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,open_spiel,checkpointing]" -``` -or a subset of these. - -To install torchrl with the latest pytorch, use -```bash -pip3 install "torchrl[replay_buffer]" -``` -since some features in the replay buffer require PyTorch 2.7.0 or above. - -One may also desire to install the library locally. Three main reasons can motivate this: -- the nightly/stable release isn't available for one's platform (eg, Windows 11, nightlies for Apple Silicon etc.); -- contributing to the code; -- install torchrl with a previous version of PyTorch (any version >= 2.1) (note that this should also be doable via a regular install followed - by a downgrade to a previous pytorch version -- but the C++ binaries will not be available so some feature will not work, - such as prioritized replay buffers and the like.) - - **Disclaimer**: As of today, TorchRL is roughly compatible with any pytorch version >= 2.1 and installing it will not - directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest - PyTorch to be installed and we are working hard to loosen that requirement. - The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above. - Some features (e.g., working with nested jagged tensors) may also - be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version - unless there is a strong reason not to do so. 
- -To install the library locally, start by cloning the repo: ```bash +# Install and build locally v0.8.1 of the library without cloning +pip3 install git+https://github.com/pytorch/rl@v0.8.1 +# Clone the library and build it locally +git clone https://github.com/pytorch/tensordict git clone https://github.com/pytorch/rl -``` -and don't forget to check out the branch or tag you want to use for the build: -```bash -git checkout v0.8.0 +pip install -e tensordict +pip install -e rl ``` -Go to the directory where you have cloned the torchrl repo and install it (after -installing `ninja`) -```bash -cd /path/to/torchrl/ -pip3 install ninja -U -python setup.py develop -``` +Note that tensordict local build requires `cmake` to be installed via [homebrew](https://brew.sh/) (MacOS) or another package manager +such as `apt`, `apt-get`, `conda` or `yum` but NOT `pip`, as well as `pip install "pybind11[global]"`. One can also build the wheels to distribute to co-workers using ```bash @@ -915,22 +882,22 @@ Your wheels will be stored there `./dist/torchrl.whl` and installable via pip install torchrl.whl ``` -**Warning**: Unfortunately, `pip3 install -e .` does not currently work. Contributions to help fix this are welcome! - -On M1 machines, this should work out-of-the-box with the nightly build of PyTorch. -If the generation of this artifact in MacOs M1 doesn't work correctly or in the execution the message -`(mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e'))` appears, then try - -``` -ARCHFLAGS="-arch arm64" python setup.py develop +The **nightly build** can be installed via +```bash +pip3 install tensordict-nightly torchrl-nightly ``` +which we currently only ship for Linux machines. +Importantly, the nightly builds require the nightly builds of PyTorch too. +Also, a local build of torchrl with the nightly build of tensordict may fail - install both nightlies or both local builds but do not mix them. 
-To run a quick sanity check, leave that directory (e.g. by executing `cd ~/`) -and try to import the library. -``` -python -c "import torchrl" -``` -This should not return any warning or error. + +**Disclaimer**: As of today, TorchRL is roughly compatible with any pytorch version >= 2.1 and installing it will not +directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest +PyTorch to be installed and we are working hard to loosen that requirement. +The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above. +Some features (e.g., working with nested jagged tensors) may also +be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version +unless there is a strong reason not to do so. **Optional dependencies** @@ -959,43 +926,6 @@ pip3 install tensorboard pip3 install wandb ``` -**Troubleshooting** - -If a `ModuleNotFoundError: No module named ‘torchrl._torchrl` errors occurs (or -a warning indicating that the C++ binaries could not be loaded), -it means that the C++ extensions were not installed or not found. - -- One common reason might be that you are trying to import torchrl from within the - git repo location. The following code snippet should return an error if - torchrl has not been installed in `develop` mode: - ``` - cd ~/path/to/rl/repo - python -c 'from torchrl.envs.libs.gym import GymEnv' - ``` - If this is the case, consider executing torchrl from another location. -- If you're not importing torchrl from within its repo location, it could be - caused by a problem during the local installation. Check the log after the - `python setup.py develop`. One common cause is a g++/C++ version discrepancy - and/or a problem with the `ninja` library. -- If the problem persists, feel free to open an issue on the topic in the repo, - we'll make our best to help! 
-- On **MacOs**, we recommend installing XCode first. - With Apple Silicon M1 chips, make sure you are using the arm64-built python - (e.g. [here](https://betterprogramming.pub/how-to-install-pytorch-on-apple-m1-series-512b3ad9bc6)). - Running the following lines of code - ``` - wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py - python collect_env.py - ``` - should display - ``` - OS: macOS *** (arm64) - ``` - and not - ``` - OS: macOS **** (x86_64) - ``` - Versioning issues can cause error message of the type ```undefined symbol``` and such. For these, refer to the [versioning issues document](https://github.com/pytorch/rl/blob/main/knowledge_base/VERSIONING_ISSUES.md) for a complete explanation and proposed workarounds. diff --git a/setup.py b/setup.py index a788cfebf37..d9547777312 100644 --- a/setup.py +++ b/setup.py @@ -195,13 +195,7 @@ def _main(argv): sys.argv = [sys.argv[0]] + unknown extra_requires = { - "atari": [ - "gym", - "atari-py", - "ale-py", - "gym[accept-rom-license]", - "pygame", - ], + "atari": ["gymnasium[atari]"], "dm_control": ["dm_control"], "replay_buffer": ["torch>=2.7.0"], "gym_continuous": ["gymnasium<1.0", "mujoco"], diff --git a/sota-implementations/grpo/README.md b/sota-implementations/grpo/README.md new file mode 100644 index 00000000000..f60fafe634f --- /dev/null +++ b/sota-implementations/grpo/README.md @@ -0,0 +1,136 @@ +# GRPO: Group Relative Policy Optimization + +This is an implementation of GRPO for language models, built on top of TorchRL. + +## Overview + +GRPO is a method for training language models using reinforcement learning, with the following key features: +- Multi-GPU support with efficient device management +- Mixed precision training +- Gradient accumulation +- Automatic checkpointing +- Comprehensive logging with Weights & Biases +- Hydra configuration system + +## Installation + +1. 
Install dependencies: +```bash +# GSM8K deps +pip install -r sota-implementations/grpo/requirements_gsm8k.txt +# IFEval deps +pip install -r sota-implementations/grpo/requirements_ifeval.txt +``` + +2. Set required environment variables: +```bash +export VLLM_USE_V1=0 # Required for vLLM compatibility +``` + +## Hardware Requirements + +- At least 3 CUDA-capable GPUs: + - Training device(s) + - vLLM inference device + - Reference model device + +Devices can be controlled via the `training_model.devices`, `inference_model.devices` and `ref_model.devices` arguments. + +## Configuration + +The training configuration is managed through Hydra. There are two main configuration files: +- `config/grpo_gsm8k.yaml`: Default configuration for GSM8K tasks (default) +- `config/grpo_ifeval.yaml`: Configuration optimized for IFEval tasks + +## Usage + +### Basic Training + +```bash +python grpo.py +``` + +### Run with IFEval Config + +```bash +python grpo.py --config-name grpo_ifeval +``` + +### Override Config Values + +```bash +# Change dataset +python grpo.py env.dataset=ifeval + +# Modify training parameters +python grpo.py train.epochs=2 train.optimizer.lr=2e-5 + +# Change model +python grpo.py model.name=meta-llama/Llama-2-7b-hf +``` + +### Hyperparameter Sweeps + +```bash +# Learning rate sweep +python grpo.py --multirun train.optimizer.lr=1e-4,1e-5,1e-6 + +# Multiple parameters +python grpo.py --multirun \ + train.optimizer.lr=1e-4,1e-5 \ + policy.kl_coef=0.01,0.1 +``` + +## Monitoring + +Training progress is logged to Weights & Biases with the following metrics: +- Reward +- Advantage +- KL penalty +- Sequence length +- ESS (Effective Sample Size) +- Loss metrics (objective, clip fraction, etc.) 
+- Gradient norm + +## Checkpointing + +Checkpoints are saved every `logging.checkpoint_frequency` batches and contain: +- Model state +- Optimizer state +- Gradient scaler state (for mixed precision) +- Full configuration + +## Debugging Out-of-memory issues + +- vLLM: Reduce `inference_model.gpu_memory_utilization=FRACTION` or number of environments run + in parallel (`env.num_envs=N`). +- KL scoring: If the KL scoring is achieved on the batch of data, + reduce the number of environments (`env.num_envs=N`) run in parallel. +- Training: Reduce batch size (`train.optim_batch_size`) + +## Directory Structure + +``` +sota-implementations/grpo/ +├── config/ +│ └── grpo_gsm8k.yaml # Main configuration file +│ └── grpo_ifeval.yaml # config file for IFEval task +├── grpo.py # Training script +├── grpo_utils.py # Utility functions +└── README.md # This file +``` + +## Output Structure + +Each run creates a timestamped directory under `outputs/`: +``` +outputs/ +└── YYYY-MM-DD/ + └── HH-MM-SS/ + ├── checkpoints/ + │ └── checkpoint_*.pt + └── .hydra/ + └── config.yaml +``` + +For hyperparameter sweeps, outputs are stored under `multirun/`. diff --git a/sota-implementations/grpo/config/grpo_gsm8k.yaml b/sota-implementations/grpo/config/grpo_gsm8k.yaml new file mode 100644 index 00000000000..e0bbcfbc2fd --- /dev/null +++ b/sota-implementations/grpo/config/grpo_gsm8k.yaml @@ -0,0 +1,91 @@ +defaults: + - _self_ + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +# Environment configuration +env: + dataset: gsm8k # choices: [gsm8k, ifeval] + # Number of environments to run in parallel. This determines the batch size passed to vLLM. + # More envs consume more GPU memory. + num_envs: 4 # Reduced from 8 to save memory + # Number of times to repeat the same prompt for GRPO. This does not affect the GPU memory usage. 
+ repeats: 16 + +# Base model configuration +model: + name: Qwen/Qwen2.5-3B + compile: false + +# Training model configuration +train_model: + gradient_checkpointing: true # Enabled for memory efficiency + devices: [0] # List of GPU devices to use for training + lora: + enabled: true # Using LoRA for memory efficiency + r: 8 # LoRA rank - controls capacity of adaptations + alpha: 16 # LoRA alpha - scales the adaptations + dropout: 0.1 # Dropout probability for LoRA layers + quantization: + enabled: false # Enable 4-bit quantization for base model + attn_implementation: sdpa # choices: [flash_attention_2, flex_attention, sdpa] + torch_dtype: bfloat16 + +# Inference model configuration (vLLM) +inference_model: + devices: [1] # List of GPU devices to use for inference + gpu_memory_utilization: 0.5 + temperature: 0.8 + max_tokens: 1024 + include_stop_str_in_output: true + +# Reference model configuration +ref_model: + devices: [2] # List of GPU devices to use for reference model + quantization: + enabled: false # Enable quantization for memory efficiency + gradient_checkpointing: false # Not needed for reference model + attn_implementation: + torch_dtype: bfloat16 + +# Policy configuration +policy: + kl_coef: 1e-2 + +# Training configuration +train: + epochs: 1 + # Number of dialog turns per batch. This is passed to the collector and buffer. + # More steps do not consume more GPU memory, but it does affect the inference speed in + # that in sync contexts the training node will need to wait for a batch to be completed + # before starting the next one. + steps_per_batch: 64 + # Total number of dialog turns to collect during training + total_dialog_turns: 1_000_000 + # Number of batches to run in parallel. This determines the batch size passed to the optimizer. + # More batches consume more GPU memory. + optim_batch_size: 1 + # Number of gradient accumulation steps. This determines the number of steps to run before + # updating the parameters. 
+ gradient_accumulation_steps: 4 # Increased for gradient accumulation + # Whether to include the KL coefficient in the loss or in the environment reward. + kl_coef_in_loss: true + # Whether to use mixed precision. + mixed_precision: true # Enable mixed precision training (autocast) + optimizer: + name: AdamW + lr: 1e-5 + clip_grad_norm: 0.5 + +# Logging configuration +logging: + checkpoint_dir: checkpoints + experiment_name: null # auto-generated if null + checkpoint_frequency: 10 # save every N batches + +hydra: + run: + dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/sota-implementations/grpo/config/grpo_ifeval.yaml b/sota-implementations/grpo/config/grpo_ifeval.yaml new file mode 100644 index 00000000000..3477b5a0bee --- /dev/null +++ b/sota-implementations/grpo/config/grpo_ifeval.yaml @@ -0,0 +1,91 @@ +defaults: + - _self_ + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +# Environment configuration +env: + dataset: ifeval # choices: [gsm8k, ifeval] + # Number of environments to run in parallel. This determines the batch size passed to vLLM. + # More envs consume more GPU memory. + num_envs: 2 + # Number of times to repeat the same prompt for GRPO. This does not affect the GPU memory usage. 
+ repeats: 16 + +# Base model configuration +model: + name: Qwen/Qwen2.5-3B + compile: false + +# Training model configuration +train_model: + gradient_checkpointing: true # Only for training model + devices: [0] # List of GPU devices to use for training + lora: + enabled: true + r: 8 # LoRA rank - controls capacity of adaptations + alpha: 16 # LoRA alpha - scales the adaptations + dropout: 0.1 # Dropout probability for LoRA layers + quantization: + enabled: false # Quantization might interfere with training + attn_implementation: sdpa # choices: [flash_attention_2, flex_attention, sdpa] + torch_dtype: bfloat16 + +# Inference model configuration (vLLM) +inference_model: + devices: [1] # List of GPU devices to use for inference + gpu_memory_utilization: 0.5 + temperature: 0.8 + max_tokens: 2048 + include_stop_str_in_output: true + +# Reference model configuration +ref_model: + devices: [2] # List of GPU devices to use for reference model + quantization: + enabled: false # Enable quantization for memory efficiency + gradient_checkpointing: false # Not needed for reference model + attn_implementation: + torch_dtype: bfloat16 + +# Policy configuration +policy: + kl_coef: 1e-2 + +# Training configuration +train: + epochs: 1 + # Number of dialog turns per batch. This is passed to the collector and buffer. + # More steps do not consume more GPU memory, but it does affect the inference speed in + # that in sync contexts the training node will need to wait for a batch to be completed + # before starting the next one. + steps_per_batch: 16 + # Total number of dialog turns to collect during training + total_dialog_turns: 1_000_000 + # Number of batches to run in parallel. This determines the batch size passed to the optimizer. + # More batches consume more GPU memory. + optim_batch_size: 1 + # Number of gradient accumulation steps. This determines the number of steps to run before + # updating the parameters. 
+ gradient_accumulation_steps: 4 + # Whether to include the KL coefficient in the loss or in the environment reward. + kl_coef_in_loss: true + # Whether to use mixed precision. + mixed_precision: true # Enable mixed precision training (autocast) + optimizer: + name: AdamW + lr: 1e-5 + clip_grad_norm: 0.5 + +# Logging configuration +logging: + checkpoint_dir: checkpoints + experiment_name: null # auto-generated if null + checkpoint_frequency: 10 # save every N batches + +hydra: + run: + dir: outputs/${now:%Y-%m-%d}/${now:%H-%M-%S} + sweep: + dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/sota-implementations/grpo/grpo.py b/sota-implementations/grpo/grpo.py new file mode 100644 index 00000000000..6733fab9872 --- /dev/null +++ b/sota-implementations/grpo/grpo.py @@ -0,0 +1,328 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""GRPO: Group Relative Policy Optimization + +This module implements GRPO training for language models. 
+""" +from __future__ import annotations + +import gc +import logging +import os +from pathlib import Path + +import hydra +import torch +import tqdm +from grpo_utils import get_inference_model, get_ref_model, get_train_model +from omegaconf import DictConfig +from tensordict import set_list_to_stack, TensorDict +from torch.cuda.amp import GradScaler +from torchrl._utils import logger as torchrl_logger +from torchrl.collectors.llm import LLMCollector +from torchrl.collectors.llm.weight_update.vllm import vLLMUpdater +from torchrl.data import LazyStackStorage, ReplayBuffer, SamplerWithoutReplacement +from torchrl.envs.llm import GSM8KEnv, KLRewardTransform +from torchrl.envs.llm.datasets.ifeval import IFEvalEnv +from torchrl.objectives.llm.grpo import GRPOLoss, MCAdvantage +from torchrl.record import WandbLogger + + +def setup_environment() -> None: + """Setup required environment variables and configurations.""" + if os.getenv("VLLM_USE_V1", "1") != "0": + raise RuntimeError("VLLM_USE_V1=0 must be set in environment") + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for training") + + # Set default dtype to float32 for mixed precision training + torch.set_default_dtype(torch.float32) + torch.set_default_device("cuda:0") + set_list_to_stack(True).set() + + # Ensure CUDA is using the correct dtype + if torch.cuda.is_available(): + torch.cuda.set_device("cuda:0") + + # Set all loggers to WARNING by default + logging.getLogger().setLevel(logging.WARNING) + # But keep torchrl at INFO + logging.getLogger("torchrl").setLevel(logging.INFO) + + +@hydra.main(version_base=None, config_path="config", config_name="grpo_gsm8k") +def train(cfg: DictConfig) -> None: + """Main training loop. 
+ + Args: + cfg: Hydra configuration object + """ + import ray + + ray.init() + + # Initialize models using config-based device allocation + torchrl_logger.info("Initializing models...") + torchrl_logger.info("Inference model...") + policy = get_inference_model(cfg) + torchrl_logger.info("Training model...") + policy_training, train_tokenizer = get_train_model(cfg) + torchrl_logger.info("Reference model...") + ref_model = get_ref_model(cfg, train_tokenizer) + torchrl_logger.info("Done initializing models") + + # Get reference model device for KL transform + ref_devices = cfg.ref_model.get("devices", [2]) + ref_device = ref_devices[0] # Use first device for KL transform + + # Setup environment + if cfg.env.dataset == "gsm8k": + env = GSM8KEnv( + repeats=cfg.env.repeats, + tokenizer=train_tokenizer, + num_envs=cfg.env.num_envs, + ) + else: # ifeval + env = IFEvalEnv( + repeats=cfg.env.repeats, + tokenizer=train_tokenizer, + num_envs=cfg.env.num_envs, + ) + + env = env.append_transform( + KLRewardTransform( + actor=ref_model, + coef=cfg.policy.kl_coef, + device=torch.device(f"cuda:{ref_device}"), + add_to_reward=not cfg.train.kl_coef_in_loss, + ) + ) + + # Setup replay buffer + rb = ReplayBuffer( + storage=LazyStackStorage(cfg.train.steps_per_batch), + sampler=SamplerWithoutReplacement(), + batch_size=cfg.train.optim_batch_size, + ) + rb.append_transform(MCAdvantage(grpo_size=cfg.env.repeats)) + + # Setup collector + model_metadata = { + k: (v.dtype, v.shape) + for k, v in policy_training.model.merge_and_unload().state_dict().items() + } + updater = vLLMUpdater( + master_address=None, + master_port=None, + model_metadata=model_metadata, + ) + + collector = LLMCollector( + env, + policy=policy, + dialog_turns_per_batch=cfg.train.steps_per_batch, + total_dialog_turns=cfg.train.total_dialog_turns, + weight_updater=updater, + ) + updater.maybe_init_group() + + # Get training device for batch processing + train_devices = cfg.train_model.get("devices", [0]) + train_device = 
train_devices[0] # Use first device for batch processing + + # Setup loss and optimizer + loss_fn = GRPOLoss( + actor_network=policy_training, + kl_to_ref_coeff=cfg.policy.kl_coef if cfg.train.kl_coef_in_loss else 0.0, + ) + if cfg.model.compile: + loss_fn = torch.compile(loss_fn) + + optim = getattr(torch.optim, cfg.train.optimizer.name)( + policy_training.model.parameters(), + lr=cfg.train.optimizer.lr, + foreach=len(train_devices) == 1, + ) + + # Only use GradScaler with float16, not with bfloat16 + use_grad_scaling = ( + cfg.train.mixed_precision and cfg.train_model.torch_dtype == "float16" + ) + scaler = GradScaler(enabled=use_grad_scaling) + + # Setup logging + experiment_name = ( + cfg.logging.experiment_name + or f"{cfg.model.name.split('/')[-1]}_{cfg.env.dataset}" + ) + wandb_logger = WandbLogger(exp_name=experiment_name, config=dict(cfg)) + + # Create checkpoint directory + checkpoint_dir = Path(cfg.logging.checkpoint_dir) / experiment_name + checkpoint_dir.mkdir(parents=True, exist_ok=True) + + # Initialize weights + torchrl_logger.info("Initializing weights...") + # Ensure weights are on the first training device for vLLM + state_dict = TensorDict(policy_training.model.merge_and_unload().state_dict()).to( + f"cuda:{train_device}" + ) + collector.update_policy_weights_(state_dict, worker_ids=[0]) + del state_dict + torch.cuda.empty_cache() + gc.collect() + + # Training loop + for i, trajs in enumerate(collector): + torchrl_logger.info(f"Collected batch {i}: {len(trajs)} trajectories") + + # Clear memory after collection + torch.cuda.empty_cache() + gc.collect() + + trajs = trajs.reshape(-1) + rb.extend(trajs) + + # Calculate metrics + with torch.no_grad(): + reward = torch.cat(rb[:].get(("next", "reward"), as_list=True)).mean() + kl_penalty = torch.cat( + rb[:].get(("next", "kl_penalty"), as_list=True) + ).mean() + seq_length = torch.tensor( + [t.numel() for t in rb[:].get("tokens_response", as_list=True)], + dtype=torch.float, + ).mean() + metrics = { + 
"reward": float(reward), + "kl_penalty": float(kl_penalty), + "seq_length": float(seq_length), + } + + # Clear memory after metrics calculation + del trajs + torch.cuda.empty_cache() + gc.collect() + + if not reward: + torchrl_logger.info("No reward - skipping batch") + torch.cuda.empty_cache() + continue + + # Training epochs + for epoch in range(cfg.train.epochs): + torchrl_logger.info(f"Epoch {epoch}") + pbar = tqdm.tqdm(total=len(rb) // cfg.train.optim_batch_size) + + for batch_idx, batch in enumerate(rb): + # Move batch to device and clear CPU memory + batch = batch.to(train_device) + torch.cuda.empty_cache() + + pbar.update(1) + + # Forward pass + with torch.amp.autocast( + "cuda", + enabled=cfg.train.mixed_precision, + dtype=getattr(torch, cfg.train_model.torch_dtype), + ): + loss = loss_fn(batch) + loss_val = loss.mean(reduce=True) + loss_val = loss_val / cfg.train.gradient_accumulation_steps + + # Store metrics before clearing memory + metrics.update( + { + "ESS": float(loss.ESS), + "loss_objective": float(loss.loss_objective), + "clip_fraction": float(loss.clip_fraction), + "kl_approx": float(loss.kl_approx), + "entropy": float(loss.loss_entropy.mean()), + "kl_to_ref": float(loss.kl_to_ref.mean()), + "loss_kl_to_ref": float(loss.loss_kl_to_ref.mean()), + } + ) + + # Clear intermediate tensors + del loss + torch.cuda.empty_cache() + + # Backward pass with gradient scaling only for float16 + if use_grad_scaling: + scaler.scale(loss_val).backward() + else: + loss_val.backward() + + if (batch_idx + 1) % cfg.train.gradient_accumulation_steps == 0: + # Clip gradients + if use_grad_scaling: + scaler.unscale_(optim) + + grad_norm = torch.nn.utils.clip_grad_norm_( + policy_training.model.parameters(), + cfg.train.optimizer.clip_grad_norm, + ) + + # Optimizer step with or without scaler + if use_grad_scaling: + scaler.step(optim) + scaler.update() + else: + optim.step() + + optim.zero_grad() + + # Clear memory after optimization step + del loss_val + 
torch.cuda.empty_cache() + gc.collect() + + # Log metrics + for name, value in metrics.items(): + wandb_logger.log_scalar(name, value) + wandb_logger.log_scalar("grad_norm", float(grad_norm)) + + pbar.close() + # Clear memory after each epoch + torch.cuda.empty_cache() + gc.collect() + + # Update policy weights + torchrl_logger.info("Updating policy weights...") + # Ensure weights are on the first training device for vLLM + state_dict = TensorDict( + policy_training.model.merge_and_unload().state_dict() + ).to(f"cuda:{train_device}") + collector.update_policy_weights_( + policy_weights=state_dict, + worker_ids=[0], + ) + del state_dict + + # Clear memory after weight update + torch.cuda.empty_cache() + gc.collect() + + # Save checkpoint + if (i + 1) % cfg.logging.checkpoint_frequency == 0: + torchrl_logger.info( + f"Saving checkpoint {(i+1) // cfg.logging.checkpoint_frequency}..." + ) + checkpoint = { + "batch": i, + "model_state_dict": policy_training.model.state_dict(), + "optimizer_state_dict": optim.state_dict(), + "scaler_state_dict": scaler.state_dict(), + "config": dict(cfg), + } + torch.save(checkpoint, checkpoint_dir / f"checkpoint_{i:04d}.pt") + + +if __name__ == "__main__": + # Setup environment + setup_environment() + train() diff --git a/sota-implementations/grpo/grpo_utils.py b/sota-implementations/grpo/grpo_utils.py new file mode 100644 index 00000000000..981dd9fff33 --- /dev/null +++ b/sota-implementations/grpo/grpo_utils.py @@ -0,0 +1,355 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+from __future__ import annotations + +import os +from typing import Any, Literal + +import torch +from omegaconf import DictConfig +from torch import device as torch_device, dtype as torch_dtype + +from torchrl import logger as torchrl_logger +from torchrl.modules.llm import TransformersWrapper, vLLMWrapper +from transformers.models.auto.modeling_auto import AutoModelForCausalLM +from transformers.tokenization_utils import PreTrainedTokenizer + + +def get_train_model( + cfg: DictConfig, +) -> tuple[TransformersWrapper, PreTrainedTokenizer]: + """Creates and configures the training model with LoRA adapters. + + This function initializes the main training model with LoRA adapters and other + training-specific configurations like gradient checkpointing. The model is wrapped + in a TransformersWrapper for policy training. + + Args: + cfg (DictConfig): The hydra configuration object containing model and training settings. + Expected to have train_model section with LoRA, quantization, and other + training-specific parameters. 
+ + Returns: + tuple[TransformersWrapper, PreTrainedTokenizer]: + - policy_training: The wrapped training model + - train_tokenizer: The tokenizer for the model + + Raises: + RuntimeError: If CUDA is not available or if device allocation fails + """ + torchrl_logger.info("Creating train model") + + # Set model dtype explicitly + model_dtype = getattr(torch, cfg.train_model.torch_dtype) + + # Get configured devices or default to [0] + train_devices = cfg.train_model.get("devices", [0]) + + # Create max_memory dict - set 0 memory for GPUs we don't want to use + max_memory = {} + for i in range(torch.cuda.device_count()): + if i in train_devices: + max_memory[i] = "24GiB" # Allow max memory for devices we want to use + else: + max_memory[i] = "0GiB" # No memory for other devices + max_memory["cpu"] = "24GiB" # Allow CPU memory as fallback + + # Let HF handle distribution with max_memory + device_map = "balanced" if len(train_devices) > 1 else f"cuda:{train_devices[0]}" + train_model, train_tokenizer = get_hf_model( + cfg.model.name, + device_map=device_map, + max_memory=max_memory, + lora=cfg.train_model.lora.enabled, + lora_r=cfg.train_model.lora.r, + lora_alpha=cfg.train_model.lora.alpha, + lora_dropout=cfg.train_model.lora.dropout, + gradient_checkpointing=cfg.train_model.gradient_checkpointing, + quantize=cfg.train_model.quantization.enabled, + torch_dtype=model_dtype, + attn_implementation=cfg.train_model.attn_implementation, + compile=cfg.model.compile, + ) + + # Force all model parameters to the same dtype + for param in train_model.parameters(): + param.data = param.data.to(model_dtype) + + policy_training = TransformersWrapper( + train_model, + tokenizer=train_tokenizer, + from_text=False, + generate=False, + return_log_probs=True, + ) + return policy_training, train_tokenizer + + +def get_inference_model(cfg: DictConfig) -> vLLMWrapper: + """Creates the vLLM-based inference model for fast generation. 
def get_inference_model(cfg: DictConfig) -> vLLMWrapper:
    """Build the vLLM-backed inference policy for fast generation.

    Spins up a vLLM model server for efficient inference and wraps it in a
    vLLMWrapper. vLLM provides optimized generation with better throughput
    than standard HuggingFace generation.

    Args:
        cfg (DictConfig): The hydra configuration object containing model settings.
            Expected to have inference_model section with vLLM-specific parameters
            like gpu_memory_utilization and generation settings.

    Returns:
        vLLMWrapper: The wrapped vLLM model ready for inference.

    Raises:
        AssertionError: If the vLLM server or model initialization fails
    """
    from torchrl.modules.llm.backends.vllm import make_vllm_worker

    vllm_devices = cfg.inference_model.get("devices", [1])
    torchrl_logger.info(f"Creating inference model on devices {vllm_devices}")

    # Generation options forwarded verbatim to the wrapper.
    sampling_options = {
        "max_tokens": cfg.inference_model.max_tokens,
        "include_stop_str_in_output": cfg.inference_model.include_stop_str_in_output,
        "temperature": cfg.inference_model.temperature,
    }

    # vLLM handles device mapping internally; pass a list for type compatibility.
    server = make_vllm_worker(
        cfg.model.name,
        gpu_memory_utilization=cfg.inference_model.gpu_memory_utilization,
        devices=list(vllm_devices),
        make_ray_worker=True,
    )
    assert server is not None

    policy = vLLMWrapper(
        server,
        from_text=True,
        return_log_probs=True,
        generate_kwargs=sampling_options,
    )
    assert policy.model is not None
    return policy
def get_ref_model(
    cfg: DictConfig, tokenizer: PreTrainedTokenizer
) -> TransformersWrapper:
    """Build the frozen reference model used for KL penalty computation.

    Initializes a frozen copy of the base model to serve as the reference for
    KL divergence computation. The reference model is typically quantized and
    does not require gradient computation.

    Args:
        cfg (DictConfig): The hydra configuration object containing model settings.
            Expected to have ref_model section with quantization and attention settings.
        tokenizer (PreTrainedTokenizer): The tokenizer to use with the reference model.

    Returns:
        TransformersWrapper: The wrapped reference model in eval mode with detached weights.
    """
    from tensordict import TensorDict

    torchrl_logger.info("Creating ref model")

    # Devices to host the reference model on; default to GPU 2.
    ref_devices = cfg.ref_model.get("devices", [2])

    # Budget "24GiB" on the selected devices, zero elsewhere, CPU as fallback.
    gpu_count = torch.cuda.device_count()
    max_memory = {
        idx: ("24GiB" if idx in ref_devices else "0GiB") for idx in range(gpu_count)
    }
    max_memory["cpu"] = "24GiB"

    # Let HF spread the weights only when more than one device is requested.
    if len(ref_devices) > 1:
        device_map = "balanced"
    else:
        device_map = f"cuda:{ref_devices[0]}"

    base_model, _ = get_hf_model(
        cfg.model.name,
        device_map=device_map,
        max_memory=max_memory,
        torch_dtype=getattr(torch, cfg.ref_model.torch_dtype),
        quantize=cfg.ref_model.quantization.enabled,
        gradient_checkpointing=cfg.ref_model.gradient_checkpointing,
        attn_implementation=cfg.ref_model.attn_implementation,
        lora=False,  # Reference model doesn't need LoRA
        requires_grad=False,
    )
    base_model = base_model.eval()
    # Detach weights so the reference model never participates in autograd.
    TensorDict.from_module(base_model).data.to_module(base_model)

    return TransformersWrapper(
        base_model,
        tokenizer=tokenizer,
        from_text=False,
        generate=False,
        return_log_probs=True,
    )
def get_hf_model(
    model_name: str,
    torch_dtype: torch_dtype = torch.float32,
    lora_r: int = 8,
    lora_alpha: int = 16,
    lora_dropout: float = 0.1,
    quantize: bool = False,
    fsdp: str = "",
    fsdp_config: Any = None,
    gradient_checkpointing: bool = True,
    device_map: str
    | dict[str, int | str | torch_device]
    | int
    | torch_device
    | None = None,
    lora: bool = True,
    attn_implementation: Literal["flash_attention_2", "flex_attention", "sdpa"]
    | None = "flex_attention",
    requires_grad: bool = True,
    compile: bool = False,
    max_memory: dict[str, str] | None = None,
) -> tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
    """Creates and configures a HuggingFace model with optional optimizations.

    Args:
        model_name (str): HuggingFace model identifier (e.g., "Qwen/Qwen2.5-3B")
        torch_dtype (torch.dtype, optional): Model precision. Default: torch.float32
        lora_r (int, optional): LoRA rank - controls capacity of adaptations. Default: 8
        lora_alpha (int, optional): LoRA alpha - scales the adaptations. Default: 16
        lora_dropout (float, optional): Dropout probability for LoRA layers. Default: 0.1
        quantize (bool, optional): Whether to enable 4-bit quantization. Default: False
        fsdp (str, optional): Fully Sharded Data Parallel configuration. Default: ""
        fsdp_config (Any, optional): Additional FSDP configurations. Default: None
        gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Default: True
        device_map (str | dict | int | torch.device | None, optional): Device placement strategy. Default: None
        lora (bool, optional): Whether to apply LoRA adapters. Default: True
        attn_implementation (Literal["flash_attention_2", "flex_attention", "sdpa"] | None, optional):
            Attention implementation to use. Default: "flex_attention"
        requires_grad (bool, optional): Whether to enable gradient computation. Default: True
        compile (bool, optional): Whether to enable model compilation. Default: False.
            NOTE(review): currently accepted but never used in this function —
            wire it to ``torch.compile`` or drop it; confirm intent.
        max_memory (dict[str, str] | None, optional): Memory configuration for
            distributed training. Default: None (treated as an empty dict).

    Returns:
        tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
            - model: The configured HuggingFace model
            - tokenizer: The associated tokenizer

    Raises:
        ImportError: If required dependencies are not installed
        RuntimeError: If model initialization fails
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if max_memory is None:
        max_memory = {}

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # A pad token identical to eos would cause the eos log-prob to be stripped
    # as padding downstream, so replace it with a distinct token.
    # NOTE(review): "PAD" is not registered via tokenizer.add_special_tokens,
    # so it is resolved against the existing vocab — confirm this is intended.
    if tokenizer.pad_token == tokenizer.eos_token:
        tokenizer.pad_token = "PAD"
    tokenizer.padding_side = "left"

    # Configure model settings for mixed precision.
    # Temporarily change the process-wide default dtype so modules created
    # during from_pretrained follow the requested precision; the original
    # dtype is restored in the `finally` block below.
    original_dtype = torch.get_default_dtype()
    torch.set_default_dtype(torch_dtype)

    model_configs = {
        "torch_dtype": torch_dtype,
        "device_map": device_map if device_map is not None else "auto",
        "max_memory": max_memory,
    }
    if torch.cuda.is_available() and attn_implementation:
        torchrl_logger.info(f"{attn_implementation} init")
        model_configs["attn_implementation"] = attn_implementation

    try:
        # Configure training settings based on FSDP usage: under FSDP the
        # quantized weights are stored in the compute dtype.
        if fsdp != "" and fsdp_config is not None:
            torchrl_logger.info("Configurations for FSDP")
            bnb_config_params = {"bnb_4bit_quant_storage": torch_dtype}
        else:
            bnb_config_params = {}

        # Enable 4-bit quantization (NF4 with double quantization).
        if quantize:
            try:
                from transformers.utils.quantization_config import BitsAndBytesConfig
            except ImportError as e:
                # Chain the original error so the missing dependency is visible.
                raise ImportError(
                    "Please install transformers with bitsandbytes support"
                ) from e

            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch_dtype,
                **bnb_config_params,
            )
            model_configs["quantization_config"] = bnb_config

        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            # The KV cache is incompatible with gradient checkpointing.
            use_cache=not gradient_checkpointing,
            cache_dir="/tmp/.cache",
            **model_configs,
        )

        # Configure gradient checkpointing based on FSDP usage
        if fsdp == "" and fsdp_config is None:
            if gradient_checkpointing:
                torchrl_logger.info("gradient_checkpointing enabled")
                model.gradient_checkpointing_enable()
        else:
            if gradient_checkpointing:
                torchrl_logger.info("gradient_checkpointing enabled")
                # Non-reentrant checkpointing is required under FSDP.
                model.gradient_checkpointing_enable(
                    gradient_checkpointing_kwargs={"use_reentrant": False}
                )

        if lora:
            try:
                from peft import get_peft_model, LoraConfig
            except ImportError as e:
                raise ImportError("Please install peft: pip install peft") from e

            # Create LoRA config with explicit dtype setting
            lora_config = LoraConfig(
                r=lora_r,
                lora_alpha=lora_alpha,
                target_modules="all-linear",
                lora_dropout=lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
                inference_mode=False,
                init_lora_weights=True,  # This ensures weights are initialized
            )

            # Initialize LoRA model
            model = get_peft_model(
                model,
                lora_config,
                autocast_adapter_dtype=False,  # Prevent automatic casting of adapter layers
            )

            # Force LoRA layers to correct dtype (only LoRA parameters).
            for n, p in model.named_parameters():
                if "lora_" in n:
                    p.data = p.data.to(torch_dtype)

        if requires_grad:
            model.requires_grad_(True)

        return model, tokenizer

    finally:
        # Restore original dtype
        torch.set_default_dtype(original_dtype)
b/sota-implementations/grpo/requirements_ifeval.txt @@ -0,0 +1,16 @@ +torch==2.7.0 +transformers==4.52.4 +peft==0.15.2 +bitsandbytes==0.46.0 +datasets==3.6.0 +wandb==0.19.11 +hydra-core==1.3.2 +ray==2.46.0 +tqdm==4.67.1 +tensordict==0.9.0 +vllm==0.9.0.1 +accelerate==1.7.0 +xformers==0.0.30 +nltk==3.9.1 +langdetect==1.0.9 +immutabledict==4.2.1 diff --git a/test/llm/test_envs.py b/test/llm/test_envs.py index c7a7bc17aa5..e0fa44d8ea2 100644 --- a/test/llm/test_envs.py +++ b/test/llm/test_envs.py @@ -48,6 +48,15 @@ ) +@pytest.fixture(scope="module", autouse=True) +def list_to_stack_fixture(): + import tensordict + + with tensordict.set_list_to_stack(True): + yield + return + + @pytest.fixture(scope="session", autouse=True) def set_list_to_stack_for_test(): with set_list_to_stack(True): @@ -463,7 +472,7 @@ def ref_model(self): tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") model = OPTForCausalLM.from_pretrained("facebook/opt-125m").eval() - tokenizer.pad_token = tokenizer.eos_token + tokenizer.pad_token = "<|PAD|>" tokenizer.padding_side = "left" yield model, tokenizer @@ -498,6 +507,7 @@ def test_env_reward(self, n_envs): @pytest.mark.skipif(not _has_transformers, reason="requires transformers library") @pytest.mark.parametrize("n_envs", [1, 4]) def test_kl_bonus(self, n_envs, ref_model): + torch.manual_seed(0) ref_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") with torch.device(ref_device): @@ -516,6 +526,8 @@ def test_kl_bonus(self, n_envs, ref_model): generate=True, from_text=True, tokenizer=tokenizer, + generate_kwargs={"max_new_tokens": 20}, + tokenizer_kwargs={"add_special_tokens": False}, ) env = make_gsm8k_env(num_envs=n_envs, tokenizer=tokenizer) diff --git a/torchrl/_utils.py b/torchrl/_utils.py index aab62c68dcf..685e2208f6e 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -82,9 +82,10 @@ def format(self, record): return formatted_message -console_handler = logging.StreamHandler(stream_handler) 
+console_handler = logging.StreamHandler(stream=stream_handler) console_handler.setFormatter(_CustomFormatter()) logger.addHandler(console_handler) +console_handler.setLevel(logging.INFO) VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG)))) _os_is_windows = sys.platform == "win32" diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 13dfec8b9db..b1b4205e7b7 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -6,6 +6,7 @@ from .llm import ( AdaptiveKLController, ConstantKLController, + ContentBase, create_infinite_iterator, get_dataloader, History, @@ -114,6 +115,7 @@ "BoundedTensorSpec", "Categorical", "Choice", + "ContentBase", "Composite", "CompositeSpec", "ConstantKLController", diff --git a/torchrl/data/llm/__init__.py b/torchrl/data/llm/__init__.py index 6e74503bc41..e0d548a2f23 100644 --- a/torchrl/data/llm/__init__.py +++ b/torchrl/data/llm/__init__.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from .chat import History +from .chat import ContentBase, History from .common import LLMData from .dataset import ( create_infinite_iterator, @@ -17,8 +17,9 @@ __all__ = [ "AdaptiveKLController", - "History", "ConstantKLController", + "ContentBase", + "History", "LLMData", "PairwiseDataset", "PromptData", diff --git a/torchrl/data/llm/chat.py b/torchrl/data/llm/chat.py index 38c618bdb94..8150c977773 100644 --- a/torchrl/data/llm/chat.py +++ b/torchrl/data/llm/chat.py @@ -276,7 +276,7 @@ def apply_chat_template( @classmethod def from_text( cls, - text: str, + text: str | list[str], chat_template_name: Literal["chatml_format", "qwen"] | None = None, chat_template: str | None = None, ) -> History: diff --git a/torchrl/envs/async_envs.py b/torchrl/envs/async_envs.py index 425e93494c4..971dfc02013 100644 --- a/torchrl/envs/async_envs.py +++ b/torchrl/envs/async_envs.py @@ -226,6 +226,11 @@ def __init__( super().__init__(batch_size=[self.num_envs]) self._busy = set() + @property + def env_batch_sizes(self) -> list[torch.Size]: + """Returns the batch-sizes of every env.""" + raise NotImplementedError + def _reset( self, tensordict: TensorDictBase | None = None, @@ -236,10 +241,15 @@ def _reset( if tensordict is None: if self._stack_func in ("lazy_stack", "maybe_dense"): tensordict = LazyStackedTensorDict( - *[TensorDict() for _ in range(self.num_envs)] + *[ + TensorDict(batch_size=self.env_batch_sizes[i]) + for i in range(self.num_envs) + ] ) else: - tensordict = TensorDict(batch_size=self.num_envs) + tensordict = TensorDict( + batch_size=(self.num_envs,) + self.env_batch_sizes[0] + ) tensordict.set(self._env_idx_key, torch.arange(tensordict.shape[0])) self._async_private_reset_send(tensordict) tensordict = self._async_private_reset_recv(min_get=self.num_envs) @@ -285,10 +295,15 @@ def reset( if tensordict is None: if self._stack_func in ("lazy_stack", "maybe_dense"): tensordict = LazyStackedTensorDict( - *[TensorDict() for _ in range(self.num_envs)] + *[ + 
TensorDict(batch_size=self.env_batch_sizes[i]) + for i in range(self.num_envs) + ] ) else: - tensordict = TensorDict(batch_size=self.num_envs) + tensordict = TensorDict( + batch_size=(self.num_envs,) + self.env_batch_sizes[0] + ) tensordict.set(self._env_idx_key, torch.arange(tensordict.shape[0])) self.async_reset_send(tensordict) tensordict = self.async_reset_recv(min_get=self.num_envs) @@ -318,23 +333,30 @@ def _maybe_make_tensordict(self, tensordict, env_index, make_if_none): if isinstance(env_idx, torch.Tensor): env_idx = env_idx.tolist() if isinstance(env_idx, int): + # If we squeezed a td with shape (1,) and got a NonTensorStack -> NonTensorData, then + # unsqueezed the NonTensorData, we'd still have a NonTensorData with shape (1,) + # This will give us an integer now, but we don't want to unsqueeze the full td because then + # we'd have a td with shape (1, 1) + if tensordict.shape != (1, *self.env_batch_sizes[env_idx]): + tensordict = tensordict.unsqueeze(0) env_idx = [env_idx] - tensordict = tensordict.unsqueeze(0) elif isinstance(env_index, int): if make_if_none: if tensordict is None: - tensordict = TensorDict(batch_size=(), device=self.device) + tensordict = TensorDict( + batch_size=self.env_batch_sizes[env_index], device=self.device + ) if self.stack in ("lazy_stack", "maybe_dense"): tensordict = tensordict.unsqueeze(0) else: - tensordict = LazyStackedTensorDict(tensordict) + tensordict = lazy_stack([tensordict]) tensordict[self._env_idx_key] = NonTensorStack(env_index) env_idx = [env_index] else: if make_if_none and tensordict is None: if self.stack in ("lazy_stack", "maybe_dense"): - tensordict = LazyStackedTensorDict( - *[TensorDict(device=self.device) for _ in env_index] + tensordict = lazy_stack( + [TensorDict(device=self.device) for _ in env_index] ) else: tensordict = TensorDict( @@ -462,6 +484,16 @@ def _setup(self) -> None: input_spec = specs["input_spec"] return output_spec, input_spec + @property + def env_batch_sizes(self) -> 
list[torch.Size]: + batch_sizes = getattr(self, "_env_batch_sizes", []) + if not batch_sizes: + for _env_idx in range(self.num_envs): + self.input_queue[_env_idx].put(("batch_size", None)) + batch_sizes.append(self.output_queue[_env_idx].get()) + self._env_batch_sizes = batch_sizes + return batch_sizes + def async_step_send( self, tensordict: TensorDictBase, env_index: int | list[int] | None = None ) -> None: @@ -641,6 +673,8 @@ def _env_exec( msg, data = msg_data if msg == "get_specs": output_queue.put(env.specs) + elif msg == "batch_size": + output_queue.put(env.batch_size) elif msg == "reset": data = env.reset(data.copy()) data.set(cls._env_idx_key, NonTensorData(i)) @@ -711,6 +745,10 @@ def _setup(self) -> None: specs = torch.stack([env.specs for env in self.envs]) return specs["output_spec"].clone(), specs["input_spec"].clone() + @property + def env_batch_sizes(self) -> list[torch.Size]: + return [env.batch_size for env in self.envs] + @classmethod def _get_specs(cls, env: EnvBase): return env.specs diff --git a/torchrl/envs/llm/chat.py b/torchrl/envs/llm/chat.py index ba3616c8e3d..08fcccc499e 100644 --- a/torchrl/envs/llm/chat.py +++ b/torchrl/envs/llm/chat.py @@ -209,37 +209,25 @@ def _reset(self, tensordict: TensorDictBase | None): content = [content for _ in range(s)] # FIXME: Assume the text is not formatted and this is just content - content = TensorDict( - role=self.user_role, content=content, batch_size=self.batch_size - ) - content["role"] = content.get("role").maybe_to_stack() + role = self.user_role + for s in reversed(self.batch_size): + role = [role for _ in range(s)] + content = History(role=role, content=content, batch_size=self.batch_size) if self.system_prompt is not None: - history = TensorDict( - { - "role": self.system_role, - "content": self.system_prompt, - } - ) + role = self.system_role + system_prompt = self.system_prompt for s in reversed(self.batch_size): - history = [history for _ in range(s)] - history = TensorDict(history=history, 
batch_size=self.batch_size) - role = lazy_stack( - [ - history.get(("history", "role")), - content.get("role"), - ], - -1, - ) - content = lazy_stack( - [ - history.get(("history", "content")), - content.get("content"), - ], - -1, - ) - td = TensorDict(role=role, content=content, batch_size=self.batch_size + (2,)) + role = [role for _ in range(s)] + system_prompt = [self.system_prompt for _ in range(s)] + history = History( + role=role, + content=system_prompt, + batch_size=self.batch_size, + ) + history = lazy_stack([history, content], -1) + else: + history = content.unsqueeze(-1) # Extract history - history = History.from_tensordict(td) result = lazy_stack( list( TensorDict( diff --git a/torchrl/envs/llm/libs/mlgym.py b/torchrl/envs/llm/libs/mlgym.py index bf2e8aed38d..cd1adf1eb0d 100644 --- a/torchrl/envs/llm/libs/mlgym.py +++ b/torchrl/envs/llm/libs/mlgym.py @@ -26,7 +26,6 @@ from torchrl.data.llm import History from torchrl.envs import ConditionalSkip, GymWrapper, Transform, TransformedEnv - # Inv transforms: # Transforms to apply prior to pass the model output to the env @@ -860,4 +859,7 @@ def make_mlgym( env.append_transform(MessageToHistory()) env.append_transform(TemplateTransform(tokenizer=tokenizer)) env.append_transform(MLGymRewardAssignment()) + # # We want the env to have a batch-size of (1,) because it will be easier to interact with + # # LLMs + # env.append_transform(BatchSizeTransform(batch_size=(1,))) return env diff --git a/torchrl/envs/llm/reward/ifeval/_instructions.py b/torchrl/envs/llm/reward/ifeval/_instructions.py index 11995dbc94f..7098fd311a6 100644 --- a/torchrl/envs/llm/reward/ifeval/_instructions.py +++ b/torchrl/envs/llm/reward/ifeval/_instructions.py @@ -38,7 +38,6 @@ import string from typing import Any, Dict, Literal, Optional, Sequence, Union -import langdetect from torchrl._utils import logger as torchrl_logger from ._instructions_util import ( @@ -192,9 +191,11 @@ def check_following(self, value: str) -> bool: Returns: `True` 
if the language of `value` follows instruction; otherwise False. """ + import langdetect + try: return langdetect.detect(value) == self._language - except langdetect.LangDetectException as e: + except (langdetect.LangDetectException, ImportError) as e: # Count as instruction is followed. torchrl_logger.error( "Unable to detect language for text %s due to %s", value, e @@ -1500,7 +1501,7 @@ def build_description(self) -> str: ) return self._description_pattern - def get_instruction_args(self) -> dict[str, Any]: + def get_instruction_args(self) -> dict[str, Any] | None: return None def get_instruction_args_keys(self) -> list[str]: @@ -1509,9 +1510,11 @@ def get_instruction_args_keys(self) -> list[str]: def check_following(self, value: str) -> bool: """Checks that the response is in English and in all capital letters.""" + import langdetect + try: return value.isupper() and langdetect.detect(value) == "en" - except langdetect.LangDetectException as e: + except (langdetect.LangDetectException, ImportError) as e: # Count as instruction is followed. torchrl_logger.error( "Unable to detect language for text %s due to %s", value, e @@ -1539,9 +1542,11 @@ def get_instruction_args_keys(self) -> list[str]: def check_following(self, value: str) -> bool: """Checks that the response is in English and in all lowercase letters.""" + import langdetect + try: return value.islower() and langdetect.detect(value) == "en" - except langdetect.LangDetectException as e: + except (langdetect.LangDetectException, ImportError) as e: # Count as instruction is followed. 
torchrl_logger.error( "Unable to detect language for text %s due to %s", value, e diff --git a/torchrl/envs/llm/reward/ifeval/_scorer.py b/torchrl/envs/llm/reward/ifeval/_scorer.py index 5cabf7ef624..7f749f16b37 100644 --- a/torchrl/envs/llm/reward/ifeval/_scorer.py +++ b/torchrl/envs/llm/reward/ifeval/_scorer.py @@ -17,13 +17,15 @@ import importlib.util import re +from typing import Callable import torch -from jedi.inference.gradual.typing import Callable from tensordict import NestedKey, NonTensorData, TensorClass, TensorDict, TensorDictBase from tensordict.tensorclass import is_non_tensor +from torchrl._utils import logger as torchrl_logger + from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded from torchrl.envs import Transform @@ -147,6 +149,7 @@ def _step( response = tensordict.get(self.response_key) if is_non_tensor(response): response = response.data + torchrl_logger.info(f"{response=}") # TODO: This should be a distinct module # Regular expression patterns to match think and answer blocks diff --git a/torchrl/envs/llm/transforms/kl.py b/torchrl/envs/llm/transforms/kl.py index a02879c803e..b1fbb226a27 100644 --- a/torchrl/envs/llm/transforms/kl.py +++ b/torchrl/envs/llm/transforms/kl.py @@ -160,7 +160,8 @@ def _step( td_device, as_nested_tensor=True, layout=torch.strided ) else: - ref_log_prob = self.actor(tensordict).get(self.sample_log_prob_key) + ref_log_prob_td = self.actor(tensordict) + ref_log_prob = ref_log_prob_td.get(self.sample_log_prob_key) reward_key = self.in_keys[0] reward = next_tensordict.get(reward_key) @@ -178,13 +179,14 @@ def _step( [rew.expand(lp.shape) for rew, lp in zip(reward, ref_log_prob)], layout=torch.strided, ) - if ref_log_prob[0].shape != curr_log_prob[0].shape: - # Don't check shapes if nested - raise ValueError( - f"the log-probability tensor shapes must match, got cur_log_prob.shape={curr_log_prob[0].shape} and log_prob.shape={ref_log_prob[0].shape}. 
" - f"One possible reason is that the padding token is identical to the eos token, which means that the eos_token log_prob is truncated from the " - f"reference model output." - ) + for i in range(ref_log_prob.size(0)): + if ref_log_prob[i].shape != curr_log_prob[i].shape: + # Don't check shapes if nested + raise ValueError( + f"the log-probability tensor shapes must match, got cur_log_prob.shape={curr_log_prob[i].shape} and log_prob.shape={ref_log_prob[i].shape}. " + f"One possible reason is that the padding token is identical to the eos token, which means that the eos_token log_prob is truncated from the " + f"reference model output." + ) if reward is not None and reward.ndim != curr_log_prob.ndim: raise ValueError( "The number of dimensions of reward must be the same as the number of dimensions of the KL " diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py index b96210fac44..3c237c39ec8 100644 --- a/torchrl/envs/transforms/llm.py +++ b/torchrl/envs/transforms/llm.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations +import warnings + from copy import copy, deepcopy import torch @@ -95,6 +97,10 @@ def __init__( functional: bool | None = None, device: torch.device | None = None, ): + warnings.warn( + "This class is to be deprecated in favour of torchrl.envs.llm.KLRewardTransform. 
The two will be fused in v0.10 and this version will be removed in v0.11.", + category=DeprecationWarning, + ) if in_keys is None: in_keys = self.DEFAULT_IN_KEYS if out_keys is None: diff --git a/torchrl/modules/llm/backends/vllm.py b/torchrl/modules/llm/backends/vllm.py index a1d6ee7b5dc..9f8485ad28b 100644 --- a/torchrl/modules/llm/backends/vllm.py +++ b/torchrl/modules/llm/backends/vllm.py @@ -145,7 +145,12 @@ def __init__(self, *args, **kwargs): str(int(gpu_id)) for gpu_id in gpu_ids ) torch.cuda.set_device(0) # Since only one GPU is visible, it's cuda:0 - super().__init__(*args, device="cuda:0", **kwargs) + self.args = args + self.kwargs = kwargs + + def initialize(self): + super().__init__(*self.args, device="cuda:0", **self.kwargs) + return True def make_vllm_worker( @@ -176,6 +181,10 @@ def make_vllm_worker( ... ray.get(handle) """ if make_ray_worker: + if len(devices) > 1: + raise ValueError( + "ray-based instantiation of vLLM does not support multiple devices at the moment." + ) devices = [ torch.device(device).index if not isinstance(device, int) else device for device in devices @@ -190,7 +199,15 @@ def make_vllm_worker( if not ray.is_initialized(): ray.init() - pg = placement_group([{"GPU": 1, "CPU": 1}] * torch.cuda.device_count()) + pipeline_parallel_size = 1 + node_id = 0 + pg = placement_group( + [{"CPU": 1, "GPU": 1}] * torch.cuda.device_count(), + strategy="SPREAD" + if (pipeline_parallel_size and pipeline_parallel_size > 1) + else "STRICT_PACK", + _soft_target_node_id=node_id if pipeline_parallel_size is None else None, + ) ray.get(pg.ready()) scheduling_inference = PlacementGroupSchedulingStrategy( @@ -201,11 +218,12 @@ def make_vllm_worker( torchrl_logger.info( f"Create vLLM worker with {devices=}, {scheduling_inference=}" ) - return ray.remote( - num_gpus=len(devices), + worker_cls = ray.remote( + num_gpus=1, num_cpus=1, scheduling_strategy=scheduling_inference, - )(LLMOnDevice).remote( + )(LLMOnDevice) + worker = worker_cls.remote( 
model=model_name, # enforce_eager=True, dtype="bfloat16", @@ -215,6 +233,9 @@ def make_vllm_worker( enable_chunked_prefill=True, **kwargs, ) + ray.get(worker.initialize.remote()) + return worker + else: with _cuda_visible_devices(devices): return LLM( diff --git a/torchrl/modules/llm/policies/transformers_wrapper.py b/torchrl/modules/llm/policies/transformers_wrapper.py index 152bebc2dbb..6dc5105ed55 100644 --- a/torchrl/modules/llm/policies/transformers_wrapper.py +++ b/torchrl/modules/llm/policies/transformers_wrapper.py @@ -216,11 +216,14 @@ def forward( if not tensordict.ndim: # unsqueeze - squeeze the input try: - return self(lazy_stack([tensordict]))[0] + return self(lazy_stack([tensordict])).squeeze(0) except Exception as e: raise RuntimeError( f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional." ) from e + elif tensordict.ndim > 1: + return self(tensordict.reshape(-1)).view(tensordict.shape) + _source_device = None if self._device: _source_device = tensordict.device @@ -323,7 +326,10 @@ def _from_transformers_generate_text(self, td, out, cfg=None): _unpad_tensors(attention_mask, attention_mask, as_nested=False), ) if self.return_log_probs: - out.set(self.log_prob_key, log_probs) + out.set( + self.log_prob_key, + _unpad_tensors(log_probs, mask_sequences, as_nested=False), + ) out.set("logits", _unpad_tensors(logits, mask_sequences, as_nested=False)) return out @@ -376,7 +382,10 @@ def _from_transformers_generate_tokens(self, td, out, cfg=None): _unpad_tensors(attention_mask, attention_mask, as_nested=False), ) if self.return_log_probs: - out.set(self.log_prob_key, log_probs) + out.set( + self.log_prob_key, + _unpad_tensors(log_probs, mask_sequences, as_nested=False), + ) out.set("logits", _unpad_tensors(logits, mask_sequences, as_nested=False)) return out @@ -483,6 +492,8 @@ def _from_transformers_logprobs_tokens(self, td, out, cfg=None): log_probs, logits = self._log_probs_from_logits( total_tokens_out, 
response_input_ids, pad_val=pad_val ) + # for i in range(log_probs.size(0)): + # assert log_probs[i].shape[-1] == response_input_ids[i].shape[-1] out.set("logits", logits) out.set(self.log_prob_key, log_probs) diff --git a/torchrl/modules/llm/policies/vllm_wrapper.py b/torchrl/modules/llm/policies/vllm_wrapper.py index 1f9f796c263..0385edaaa45 100644 --- a/torchrl/modules/llm/policies/vllm_wrapper.py +++ b/torchrl/modules/llm/policies/vllm_wrapper.py @@ -243,11 +243,14 @@ def forward( if not tensordict.ndim: # unsqueeze - squeeze the input try: - return self(lazy_stack([tensordict]))[0] + return self(lazy_stack([tensordict])).squeeze(0) except Exception as e: raise RuntimeError( f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional." ) from e + elif tensordict.ndim > 1: + return self(tensordict.reshape(-1)).view(tensordict.shape) + if kwargs: sampling_params = self.sampling_params.clone() for key, val in kwargs.items():