
TorchRL 0.9.0 Release Notes

Released by @vmoens · 10 Jul 15:28 · commit 7e8f940

We are excited to announce the release of TorchRL 0.9.0! This release introduces a comprehensive LLM API for language model fine-tuning, extensive torch.compile compatibility across all algorithms, and numerous performance improvements.

🚀 Major Features

🤖 LLM API - Complete Framework for Language Model Fine-tuning

TorchRL now includes a comprehensive LLM API for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:

The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the complete documentation and GRPO implementation example to get started!

Unified LLM Wrappers

  • TransformersWrapper: Seamless integration with Hugging Face models
  • vLLMWrapper: High-performance inference with vLLM engines
  • Consistent API: Both wrappers provide unified input/output interfaces using TensorClass objects
  • Multiple input modes: Support for history, text, and tokenized inputs
  • Configurable outputs: Text, tokens, masks, and log probabilities
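The point of the unified interface is that training code can target a single call signature regardless of which backend serves the model. A minimal, framework-free sketch of the idea (all names below are illustrative stand-ins, not the TorchRL API):

```python
from dataclasses import dataclass, field
from typing import List, Protocol

# Conceptual sketch of a unified wrapper interface; illustrative names only.

@dataclass
class Completion:
    """Structured output: text plus optional tokens and log-probs."""
    text: str
    tokens: List[int] = field(default_factory=list)
    log_probs: List[float] = field(default_factory=list)

class Policy(Protocol):
    def generate(self, prompt: str) -> Completion: ...

class EchoBackend:
    """Stand-in for a real engine (Transformers, vLLM, ...)."""
    def generate(self, prompt: str) -> Completion:
        return Completion(text=prompt.upper())

def rollout(policy: Policy, prompt: str) -> str:
    # Training code only sees the shared interface, so backends stay swappable.
    return policy.generate(prompt).text

print(rollout(EchoBackend(), "hello"))  # prints "HELLO"
```

In TorchRL the analogous role is played by the TensorClass-based inputs and outputs shared by TransformersWrapper and vLLMWrapper.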

Advanced Conversation Management

  • History class: Advanced bidirectional conversation management with automatic chat template detection
  • Multi-model support: Automatic template detection for various model families (Qwen, DialoGPT, Falcon, DeepSeek, etc.)
  • Assistant token masking: Identify which tokens were generated by the assistant for RL applications
  • Tool calling support: Handle function calls and tool responses in conversations
  • Batch operations: Efficient tensor operations for processing multiple conversations
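Assistant token masking is worth unpacking: in RL and SFT, only the tokens the assistant actually generated should receive loss or credit, not the system prompt or user turns. A framework-free sketch of the idea (illustrative names, not the TorchRL API):

```python
# Conceptual sketch of assistant token masking; not the TorchRL API.
# A conversation is flattened to a token sequence, and a boolean mask marks
# which tokens the assistant generated, so losses apply only to those.

def assistant_mask(turns):
    """turns: list of (role, tokens) pairs; returns (flat_tokens, mask)."""
    flat, mask = [], []
    for role, tokens in turns:
        flat.extend(tokens)
        mask.extend([role == "assistant"] * len(tokens))
    return flat, mask

turns = [
    ("system", [1, 2]),
    ("user", [3, 4, 5]),
    ("assistant", [6, 7]),
]
tokens, mask = assistant_mask(turns)
print(tokens)  # [1, 2, 3, 4, 5, 6, 7]
print(mask)    # [False, False, False, False, False, True, True]
```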

🛠️ Tool Integration

  • PythonInterpreter transform: Built-in Python code execution capabilities
  • MCPToolTransform: General tool calling support
  • Extensible architecture: Easy to add custom tool transforms
  • Safe execution: Controlled environment for tool execution
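To illustrate the shape of tool execution, here is a deliberately simplified, framework-free sketch: it pulls tagged code snippets out of a model response, runs them, and captures their stdout so the result can be fed back into the conversation. The <python> tag convention and all names here are illustrative; TorchRL's PythonInterpreter transform is the real, properly sandboxed implementation.

```python
import contextlib
import io
import re

# Conceptual sketch of tool-style code execution; illustrative only.
CODE_TAG = re.compile(r"<python>(.*?)</python>", re.DOTALL)

def run_python_blocks(response: str) -> str:
    """Execute tagged code snippets in a response and return captured stdout."""
    out = io.StringIO()
    for code in CODE_TAG.findall(response):
        with contextlib.redirect_stdout(out):
            # A real sandbox would restrict builtins, time and memory.
            exec(code, {})
    return out.getvalue()

reply = "Let me compute that: <python>print(2 + 2)</python>"
print(run_python_blocks(reply), end="")  # prints "4"
```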

🎯 Specialized Objectives

  • GRPOLoss: Group Relative Policy Optimization loss function optimized for language models
  • SFTLoss: Supervised fine-tuning loss with assistant token masking support
  • MCAdvantage: Monte-Carlo advantage estimation for LLM training
  • KL divergence rewards: Built-in KL penalty computation
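For context on GRPOLoss and KL rewards: GRPO's central trick is the group-relative advantage, which normalizes rewards across a group of responses sampled for the same prompt, and a common reward shaping adds a KL penalty against a reference model. A framework-free sketch of the arithmetic (illustrative names, not the TorchRL API):

```python
import math

# Conceptual sketch of the group-relative advantage at the heart of GRPO,
# plus a KL-penalized reward; illustrative code, not the TorchRL API.

def group_relative_advantages(rewards):
    """Normalize rewards across a group of responses to the same prompt."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    std = math.sqrt(var)
    return [(r - mean) / (std + 1e-8) for r in rewards]

def kl_penalized_reward(reward, logp, logp_ref, beta=0.1):
    """Penalize divergence of the policy from a reference model."""
    return reward - beta * (logp - logp_ref)

advs = group_relative_advantages([1.0, 0.0, 1.0, 0.0])
print([round(a, 3) for a in advs])  # [1.0, -1.0, 1.0, -1.0]
```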

⚡ High-Performance Collectors

  • LLMCollector: Async data collection with distributed training support
  • RayLLMCollector: Multi-node distributed collection using Ray
  • Weight synchronization: Automatic model weight updates across distributed setups
  • Trajectory management: Efficient handling of variable-length conversations

🔄 Flexible Environments

  • ChatEnv: Transform-based architecture for conversation management
  • Transform-based rewards: Modular reward computation and data loading
  • Dataset integration: Built-in support for loading prompts from datasets
  • Thinking prompts: Chain-of-thought reasoning support

📚 Complete Implementation Example

A full GRPO implementation is provided in sota-implementations/grpo/ with:

  • Multi-GPU support with efficient device management
  • Mixed precision training
  • Gradient accumulation
  • Automatic checkpointing
  • Comprehensive logging with Weights & Biases
  • Hydra configuration system
  • Asynchronous training support with Ray

🆕 New Features

LLM API Components

  • LLMMaskedCategorical (#3041) - Categorical distribution with masking for LLM token selection
  • AddThinkingPrompt transform (#3027) - Add chain-of-thought reasoning prompts
  • MCPToolTransform (#2993) - Model Context Protocol tool integration
  • PythonInterpreter transform (#2988) - Python code execution in LLM environments
  • ContentBase (#2985) - Base class for structured content in LLM workflows
  • LLM Tooling (#2966) - Comprehensive tool integration framework
  • History API (#2965) - Advanced conversation management system
  • LLM collector (#2879) - Specialized data collection for language models
  • vLLM wrapper (#2830) - High-performance vLLM integration
  • Transformers policy (#2825) - Hugging Face transformers integration

Environment Enhancements

  • IsaacLab wrapper (#2937) - NVIDIA Isaac Lab environment support
  • Complete PettingZooWrapper state support (#2953) - Full state management for multi-agent environments
  • ConditionalPolicySwitch transform (#2711) - Dynamic policy switching based on conditions
  • Async environments (#2864) - Asynchronous environment execution
  • VecNormV2 (#2867) - Improved vector normalization with batched environment support

Algorithm Improvements

  • Async GRPO (#2997) - Asynchronous Group Relative Policy Optimization
  • Expert Iteration and SFT (#3017) - Expert iteration and supervised fine-tuning algorithms
  • Async SAC (#2946) - Asynchronous Soft Actor-Critic implementation
  • Multi-node Ray support for GRPO (#3040) - Distributed GRPO training

Data Management

  • RayReplayBuffer (#2835) - Distributed replay buffer using Ray
  • RayReplayBuffer usage examples (#2949) - Comprehensive usage examples
  • Policy factory for collectors (#2841) - Flexible policy creation in collectors
  • Local and Remote WeightUpdaters (#2848) - Distributed weight synchronization

Performance Optimizations

  • Deactivate vmap in objectives (#2957) - Improved performance by avoiding vmap overhead in loss computation
  • Hold a single copy of low/high in bounded specs (#2977) - Memory optimization for bounded specifications
  • Use TensorDict._new_unsafe in step (#2905) - Performance improvement in environment steps
  • Memoize calls to encode and related methods (#2907) - Caching for improved performance

Utility Features

  • Compose.pop (#3026) - Remove transforms from composition
  • Add optional Explained Variance logging (#3010) - Enhanced logging capabilities
  • Enabling worker level control on frames_per_batch (#3020) - Granular control over data collection
  • collector.start() (#2935) - Explicit collector lifecycle management
  • Timer transform (#2806) - Timing capabilities for environments
  • MultiAction transform (#2779) - Multi-action environment support
  • Transform for partial steps (#2777) - Partial step execution support

🔧 Performance Improvements

  • VecNormV2: Improved vector normalization with better bias correction timing (#2900, #2901)
  • MaskedCategorical cross_entropy: Faster loss computation (#2882)
  • Avoid padding in transformer wrapper: Memory and performance optimization (#2881)
  • Set padded token log-prob to 0.0: Improved numerical stability (#2857)
  • Better device checks: Enhanced device management (#2909)
  • Local dtype maps: Optimized dtype handling (#2936)

🐛 Bug Fixes

LLM API Fixes

  • Variable length vllm wrapper answer stacking (#3049) - Fixed stacking issues with variable-length responses
  • LLMCollector trajectory collection methods (#3018) - Fixed trajectory collection when multiple trajectories complete simultaneously
  • Fix IFEval GRPO runs (#3012) - Resolved issues with IFEval dataset runs
  • Fix cuda cache empty in GRPO scripts (#3016) - Memory management improvements
  • Right log-prob size in transformer wrapper (#2856) - Fixed log probability tensor sizing
  • Fix gc import (#2862) - Import error resolution

Environment Fixes

  • Brax memory leak fix (#3052) - Resolved memory leaks in Brax environments
  • Fix behavior of partial, nested dones in PEnv and TEnv (#2959) - Improved done state handling
  • Fix shifted value computation with an LSTM (#2941) - LSTM value computation fixes
  • Fix single action pass to gym when action key is not "action" (#2942) - Action key handling improvements
  • Fix PEnv device copies (#2840) - Device management in parallel environments

Data Management Fixes

  • Fix minari dataloading (#3054) - Resolved Minari dataset loading issues
  • RB.add unsqueezes tds when applying the transform (#3047) - Replay buffer transform handling
  • Fix PRB serialization (#2963) - Prioritized replay buffer serialization
  • Fix lazy-stack in RBs (#2880) - Lazy stacking in replay buffers
  • Keep original class in LazyStackStorage through lazy_stack (#2873) - Class preservation in lazy stacking

Algorithm Fixes

  • Fix deprecated list index (#3005) - Updated deprecated list indexing
  • update_policy_weights_() with cudagraph (#3003) - CUDA graph compatibility
  • Fix compile compatibility of PPO losses (#2889) - Compilation compatibility
  • Fix .item() warning on tensors that require grad (#2885) - Gradient tensor handling
  • Fix KL penalty (#2908) - KL divergence computation fixes

Specification and Type Fixes

  • Fixes the Categorical is_in with non-long integer (#2981) - Type compatibility improvements
  • Categorical spec samples the right dtype when masked (#2980) - Masked categorical sampling
  • Binary can have empty shape (#2979) - Empty shape handling
  • ActionMask is compatible with composite action specs (#3022) - Composite action specification support
  • Fix composite setitem (#2778) - Composite specification item setting

General Fixes

  • Fix various test failures (#2994) - Test suite improvements
  • Fix wrong split_trajectories import (#3023) - Import error resolution
  • Fix typo (#2969) - Documentation typo fixes
  • Fix device in PPO tests (#2971) - Device handling in tests
  • Fix device in args of PPO losses (#2969) - PPO loss device arguments

📚 Documentation

  • Document the LLM env and transform API (#2991) - Comprehensive LLM API documentation
  • Update documentation for _AcceptedKeys in a2c.py (#2987) - A2C documentation improvements
  • WeightUpdaterBase docs update after renaming (#3007) - Updated documentation for renamed components
  • Fix doc pipeline (#2992) - Documentation build improvements
  • Fix Doc (#2919) - General documentation fixes
  • Fix doc setup (#2922) - Documentation setup improvements
  • Better doc for Transform class (#2797) - Transform class documentation
  • Add docstring for MCTSForest.extend (#2795) - MCTS documentation
  • Fix tutorials (#2772, #2768) - Tutorial fixes and improvements

🧪 Testing and Quality

  • Fix wrong import (#3033) - Import error fixes
  • Fix error catches (#2982) - Error handling improvements
  • Fix warnings in tests (#2886) - Warning suppression in tests
  • Test and fix life cycle of env with dynamic non-tensor spec (#2812) - Environment lifecycle testing
  • Capture deprec warnings (#2799) - Deprecation warning handling
  • Fix old deps tests (#2500) - Dependency testing improvements

🔄 Refactoring and Code Quality

  • Refactor the weight update logic (#2914) - Improved weight update architecture
  • Refactor LLM data structures (#2834) - LLM data structure improvements
  • Rename RLHF files to LLM (#2833) - File organization improvements
  • Refactor TransformersWrapper class (#2871) - Transformers wrapper improvements
  • Refactor vLLMWrapper class (#2870) - vLLM wrapper improvements
  • Remove from_vllm and from_hf_transformers (#2874) - Cleanup of deprecated methods
  • Simplify LLMEnv (#2897) - LLM environment simplification

🚀 CI and Infrastructure

  • Fix win CI (#3028) - Windows CI improvements
  • Fix SDL install (#2978) - SDL installation fixes
  • Build wheels on osx 15 (#2934) - macOS 15 compatibility
  • Fix tensordict upper version to 0.9 (#2933) - Dependency version management
  • Fix nightly and benchmark CIs (#2930) - CI pipeline improvements
  • Fix envnames in SOTA tests (#2921) - Test environment naming
  • egl for all (#2915) - EGL support improvements
  • Fix LLM tests (#2918) - LLM test suite fixes
  • Fix old deps (#2916) - Dependency management
  • Upgrade to cuda 12.8 (#2820) - CUDA version upgrade
  • Fix libs workflows (#2800) - Library workflow improvements

📦 Dependencies and Setup

  • Remove distutils imports (#2836) - Modern Python compatibility
  • Fix no_python_abi_suffix error (#2863) - Python ABI suffix handling
  • Upgrade to v0.7 (#2745) - Dependency version updates
  • Fix Cairo-2 Chess import error (#2743) - Chess environment dependencies

🗑️ Deprecations and Removals

  • Enact deprecations (#2917) - Implementation of planned deprecations
  • Remove LLM features for release (#2912) - Temporary removal for release stability
  • Softly change default behavior of auto_unwrap (#2793) - Default behavior changes
  • Gracing old Spec with v0.8 versioning (#2751) - Specification versioning
  • Remove InPlaceSampler (#2750) - Deprecated sampler removal
  • Remove OrnsteinUhlenbeckProcessWrapper (#2749) - Deprecated wrapper removal
  • Remove AdditiveGaussianWrapper (#2748) - Deprecated wrapper removal
  • Remove NormalParamWrapper (#2747) - Deprecated wrapper removal
  • Change the default MLP depth (#2746) - Default configuration changes

🔧 Minor Improvements

  • Fix sota runs (#3042) - SOTA implementation improvements
  • remove unused variables in GRPO scripts (#3038) - Code cleanup
  • Fix deprecated list index (#3005) - Deprecation warning fixes
  • gitignore ipynb (#2954) - Git ignore improvements
  • Quick edits to .md files (#2931) - Documentation improvements
  • Fix typos in advantages.py (#2492) - Documentation typo fixes
  • Remove redundant return (#2925) - Code cleanup
  • Fix some typos (#2811) - Documentation improvements

📊 Migration Guide

LLM API Usage

The new LLM API provides a complete framework for language model fine-tuning. Key components include:

from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.transforms import PythonInterpreter  # import path may vary; see the LLM transforms docs
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm import GRPOLoss
from torchrl.collectors.llm import LLMCollector

# Assumes `model`, `tokenizer`, `critic` and `optimizer` are defined elsewhere.

# Create environment with Python tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You are an assistant that can execute Python code.",
    batch_size=[1],
).append_transform(PythonInterpreter())

# Wrap your language model
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history",
)

# Set up GRPO training
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
collector = LLMCollector(env, llm, frames_per_batch=100)

# Training loop
for data in collector:
    loss = loss_fn(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Breaking Changes

  • Some deprecated wrappers have been removed (NormalParamWrapper, AdditiveGaussianWrapper, etc.)
  • Default MLP depth has been changed
  • Default behavior of auto_unwrap has been modified

Performance Recommendations

  • Use the new VecNormV2 for improved normalization performance; it can be enabled through a keyword argument of the regular VecNorm transform.
  • Leverage async environments and collectors for better throughput.
  • Consider using RayReplayBuffer for distributed training scenarios.

🙏 Acknowledgments

We would like to thank all contributors who made this release possible, especially those who contributed to the LLM API framework and the comprehensive testing and documentation improvements.


For detailed usage examples and tutorials, please refer to the TorchRL documentation and the LLM API reference.