Suggestion for Fine-Grained Batch Control, e.g., per_device_train_batch_size or mini_batch_size #80

Open
dechunwang opened this issue Oct 6, 2024 · 1 comment

Comments

@dechunwang

Hello there,
First, I'd like to express my appreciation for your excellent work on this project.
While experimenting with PPO/RW using this repository, I consistently encounter Out of Memory (OOM) errors with the following configuration:

  • Model: LLaMA 7B
  • Tensor Parallelism (TP): 1
  • Pipeline Parallelism (PP): 8
  • Global Batch Size (GBS): 256
  • Sequence Length: 8192

This error is unexpected given the relatively small model size and parallelism configuration.

The project currently offers an n_mbs setting, which splits data batches into n_mbs chunks. However, this approach has limitations:

  1. It is difficult to determine the exact batch size after data packing.
  2. It is hard to set n_mbs correctly without knowing the global batch size.

After reviewing the documentation, I believe a crucial feature is missing:
per_device_train_batch_size or mini_batch_size

A direct mini_batch_size setting would provide more intuitive and precise control over batch sizes across different parallelism configurations; a rough sketch of the intended semantics follows the list below.

This setting would allow users to specify the mini-batch size for each Data Parallel (DP) rank, providing several benefits:

  1. Fine-grained control in MP/PP/TP scenarios (e.g., with virtual pipeline parallelism)
  2. Better resource utilization (e.g., a 7B-parameter model shouldn't require PP > 2 if mini_batch_size = 1)
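
Roughly, the semantics I have in mind look like the sketch below. This is illustrative pseudocode only, not based on this repository's internals; the names (split_for_dp_rank, global_batch, etc.) are placeholders.

```python
import math

def split_for_dp_rank(global_batch, dp_size, dp_rank, mini_batch_size):
    """Yield micro-batches of at most `mini_batch_size` samples for one DP rank.

    With the current `n_mbs` knob the user has to guess how many chunks yield
    the desired per-device size; specifying `mini_batch_size` directly makes
    the per-device memory footprint explicit.
    """
    # Samples owned by this data-parallel rank (round-robin split, for illustration only).
    local_batch = global_batch[dp_rank::dp_size]
    # The number of micro-batches is derived from the per-device size instead
    # of being a separately tuned `n_mbs` value.
    n_micro_batches = math.ceil(len(local_batch) / mini_batch_size)
    for i in range(n_micro_batches):
        yield local_batch[i * mini_batch_size:(i + 1) * mini_batch_size]

# Example: with DP=1, GBS=256, and mini_batch_size=1, this yields 256
# micro-batches of exactly one sample each, regardless of how data packing
# changed the number of samples in `global_batch`.
```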

Would it be possible to consider adding this feature in a future update?

@dechunwang dechunwang changed the title add mini batch size setting? Suggestion for Fine-Grained Batch Control, e.g., per_device_train_batch_size or mini_batch_size Oct 6, 2024
@garrett4wade
Contributor

Hi dechun, sorry for the late reply.

Your suggestion is quite reasonable. I will raise a PR later. You can also make the change if you would like to contribute to this project.

I'd also like to add a few comments about your scenario. Some tips:

  • Set "ppo.gen.force_no_logits_mask=True"; otherwise the logits mask will occupy a huge amount of GPU memory while yielding only a minor improvement in learning performance.

  • Use identical parallel strategies if your resource budget is tight (i.e., allocation_mode="d1m8p1"; that said, I believe TP is more memory-efficient than PP). This will disable parameter reallocation but still enable offloading.

  • For now, calculate n_mbs based on per_device_train_batch_size=1 (see the sketch below). For PPO training, the global batch size is per_device_train_batch_size * (pp_size * 2) * dp_size * ppo_n_mbs * n_mbs; for generation or inference, it is per_device_train_batch_size * pp_size * dp_size * n_mbs. Though inconvenient, I think implicitly setting per_device_train_batch_size=1 in this way will fix the OOM issue.
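
To make the arithmetic above concrete, here is a rough sketch in plain Python (not repository code; the concrete numbers are only illustrative) that fixes per_device_train_batch_size=1 and solves for n_mbs given a target global batch size:

```python
def ppo_train_global_batch_size(per_device_train_batch_size, pp_size, dp_size,
                                ppo_n_mbs, n_mbs):
    # Global batch size for PPO training, per the formula above.
    return per_device_train_batch_size * (pp_size * 2) * dp_size * ppo_n_mbs * n_mbs

def gen_global_batch_size(per_device_train_batch_size, pp_size, dp_size, n_mbs):
    # Global batch size for generation/inference, per the formula above.
    return per_device_train_batch_size * pp_size * dp_size * n_mbs

# Illustrative numbers roughly matching the issue (GBS=256, PP=8, DP=1);
# ppo_n_mbs=1 is an assumption.
target_gbs, pp_size, dp_size, ppo_n_mbs = 256, 8, 1, 1
n_mbs = target_gbs // (1 * (pp_size * 2) * dp_size * ppo_n_mbs)  # -> 16
assert ppo_train_global_batch_size(1, pp_size, dp_size, ppo_n_mbs, n_mbs) == target_gbs
```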

If OOM still happens, it would be helpful if you could share the log and CLI configuration. Discussion is welcome.
