
[RFC] validation and evaluation in torchtitan #1210


Open
tianyu-l opened this issue May 20, 2025 · 4 comments

@tianyu-l
Contributor

tianyu-l commented May 20, 2025

Recently I’ve been trying to understand evaluation tasks, which are currently missing in torchtitan and have been requested in multiple issues & PRs. This RFC aims to organize thoughts, enumerate options, and propose priorities.

Overall, “evaluation” is an overloaded term that sometimes also denotes validation.

Validation

It usually happens during training, on a small-to-moderate-size dataset, with the same or similar metric as the training loss. The purpose is often to tune hyperparameters and to monitor overfitting. In some cases, such as today’s LLM training, this use case seems less important to me – a model is almost always trained on unseen data (whose amount is usually guided by scaling-law experiments), so observing the training loss alone is as informative as the validation loss, as long as the training data is randomly distributed with low variance. Please correct me if I’m wrong.

In some other cases, it is still important. E.g. in diffusion model training, the input latent is a mixture of the target and noise according to a random timestep, which increases the training loss variance and makes it a poor indicator of training progress. Therefore, validation with fixed timesteps (and fixed data) is crucial.
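A minimal sketch of this idea (not torchtitan code; the model interface and the linear noising schedule below are hypothetical):

```python
import torch

@torch.no_grad()
def diffusion_val_loss(model, val_latents, num_timesteps=1000, seed=0):
    """Validation loss on fixed data, fixed timesteps, and fixed noise."""
    model.eval()
    device = val_latents.device
    gen = torch.Generator(device=device).manual_seed(seed)  # fixed noise
    # Evenly spaced timesteps instead of the random ones used in training.
    t = torch.linspace(0, num_timesteps - 1, val_latents.shape[0], device=device).long()
    noise = torch.randn(val_latents.shape, generator=gen, device=device)
    # Hypothetical linear noising schedule; the real schedule is model-specific.
    alpha = (1.0 - t.float() / num_timesteps).view(-1, 1, 1, 1)
    noisy = alpha * val_latents + (1.0 - alpha) * noise
    loss = torch.nn.functional.mse_loss(model(noisy, t), noise)  # model predicts noise
    model.train()
    return loss
```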

Evaluation

Evaluation is the assessment of the trained model’s performance using various downstream metrics, generalizing the concept of the “test loss” in traditional ML. Evaluation metrics are the final metrics we see in papers and technical reports. Strictly speaking, this is a one-time process and we shouldn’t tune models based on evaluation metrics, although I guess in today’s large model training this rule is no longer strictly followed.

A major scenario for evaluation is after training, on a trained model checkpoint (see user requests #758 #693 (comment)), although in-training evaluation support is also important to users (see comments #883 (comment) #693 (reply in thread)).


Proposed solutions (ordered by priority)

In-training validation

We can follow the standard practice of [solution 1] periodically running forward-only passes on a validation dataset. Examples include

We can simply add an entry for building a validation dataloader to torchtitan’s TrainSpec. A more flexible extension point would be a general “Validator”, similar to a MetricsProcessor (see the sketch below).

We need to make sure our parallelized model can run forward-only passes. E.g. PP needs to support eval_step as in pytorch/pytorch#153956.
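To make the extension point concrete, here is a minimal sketch of what such a “Validator” could look like; the class name, constructor arguments, and the periodic hook are hypothetical, not existing torchtitan APIs:

```python
import torch

class Validator:
    """Runs forward-only passes on a validation dataloader every `freq` training steps."""

    def __init__(self, model, val_dataloader, loss_fn, freq: int = 1000):
        self.model = model                  # the (parallelized) training model
        self.val_dataloader = val_dataloader
        self.loss_fn = loss_fn
        self.freq = freq

    @torch.no_grad()
    def maybe_validate(self, step: int):
        if step == 0 or step % self.freq != 0:
            return None
        self.model.eval()
        total, count = 0.0, 0
        for input_ids, labels in self.val_dataloader:
            logits = self.model(input_ids)              # forward-only pass
            total += self.loss_fn(logits, labels).item()
            count += 1
        self.model.train()
        return total / max(count, 1)
```

Under PP, the loop body would instead be driven by the pipeline schedule’s eval_step (pytorch/pytorch#153956) rather than a direct model call.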

After-training evaluation

EleutherAI’s lm-evaluation-harness (or lm_eval) is the most commonly used framework for evaluating language models. There are two reasonable approaches to integrating torchtitan with lm_eval.

HF has good integration with various open-source tooling, including other evaluation frameworks (#693 (reply in thread)) and inference engines (#693 (reply in thread)). [solution 2] Converting DCP to HF has been requested in torchtitan for almost a year (#420), and we should prioritize it.

The limitation of this approach is scaling. HF doesn’t seem to support the full set of n-D parallelisms that torchtitan does, so it is possible that one can train a model in torchtitan but cannot evaluate it (efficiently) in HF. Besides, IIUC HF’s safetensors format is not as scalable as DCP, and the conversion between DCP and HF could take a very long time for large models.
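A hedged sketch of what such a conversion script could look like, using the existing torch.distributed.checkpoint.format_utils.dcp_to_torch_save utility and safetensors; parameter-name remapping to HF conventions is a separate step, not shown:

```python
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from safetensors.torch import save_file

def dcp_to_safetensors(dcp_dir: str, out_path: str):
    # Step 1: consolidate the sharded DCP checkpoint into a single torch.save file.
    tmp_path = out_path + ".pt"
    dcp_to_torch_save(dcp_dir, tmp_path)

    # Step 2: load, keep only tensors, and write a safetensors file.
    # Note: the consolidated checkpoint may be nested (e.g. under a "model" key),
    # depending on how it was saved; flatten accordingly.
    state_dict = torch.load(tmp_path, map_location="cpu")
    if isinstance(state_dict.get("model"), dict):
        state_dict = state_dict["model"]
    tensors = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
    save_file(tensors, out_path)
```

The single consolidated file in step 1 is exactly where the scalability concern above shows up for very large models.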

The other option is to [solution 3] integrate lm_eval with torchtitan, following the lm_eval guide on “External Library Usage” (https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage). We would need to implement wrapper classes around torchtitan models exposing the interfaces lm_eval requires (a skeleton is sketched below), which would take some effort. The benefit would be full n-D parallelism support for evaluation, e.g. see user request #693 (comment).

Examples include
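For reference, the interface lm_eval expects from an external model is roughly the skeleton below; the torchtitan-side details (model forward, tokenizer) are placeholders, not real torchtitan APIs:

```python
import lm_eval
from lm_eval.api.model import LM

class TorchTitanLM(LM):
    """Hypothetical wrapper exposing a torchtitan model to lm_eval."""

    def __init__(self, model, tokenizer):
        super().__init__()
        self.model = model          # parallelized torchtitan model
        self.tokenizer = tokenizer

    def loglikelihood(self, requests):
        # For each (context, continuation) pair, return (logprob, is_greedy),
        # computed with forward-only passes through the torchtitan model.
        raise NotImplementedError

    def loglikelihood_rolling(self, requests):
        # Full-sequence log-likelihood, used by perplexity-style tasks.
        raise NotImplementedError

    def generate_until(self, requests):
        # Generation until stop sequences, used by generative tasks.
        raise NotImplementedError

# Hypothetical usage:
# results = lm_eval.simple_evaluate(
#     model=TorchTitanLM(model, tokenizer), tasks=["hellaswag"]
# )
```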

In-training evaluation

An additional benefit of integrating lm_eval with torchtitan is that it would support not only after-training eval but also in-training eval. E.g. see gpt-neox’s integration into the Megatron trainer:
https://github.com/EleutherAI/gpt-neox/blob/main/megatron/training.py#L1650
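A hedged sketch of how such an in-training hook could look in a trainer loop, reusing the TorchTitanLM wrapper sketched above; eval_freq, the task list, and the logging are illustrative placeholders:

```python
import lm_eval

def maybe_run_eval(step, model, tokenizer, eval_freq=5000, eval_tasks=("hellaswag",)):
    if step == 0 or step % eval_freq != 0:
        return None
    model.eval()
    results = lm_eval.simple_evaluate(
        model=TorchTitanLM(model, tokenizer),   # wrapper from the sketch above
        tasks=list(eval_tasks),
    )
    model.train()
    # Log per-task metrics next to the training loss (TensorBoard / W&B / stdout).
    return results["results"]
```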


Thanks for the discussions @casper-hansen @ebsmothers @eminorhan @janEbert @K-H-Ismail @samsja @wwwjn. Any feedback is welcome!

@samsja
Contributor

samsja commented May 20, 2025

I agree with all the points cited above.

I would add that in-training eval might not be that important, though. Especially for large-scale pretraining, it’s unlikely that we would want to pause the whole training to do eval on 500+ GPUs. We would rather just push the checkpoint and do the eval on a smaller cluster (or even just one node). In this context, it would not matter too much if the DCP -> HF safetensors conversion is slow, as the eval would be relatively async.

So IMO having a torchtitan-to-HF conversion script (that is tested for correctness) would be "good enough" for most eval jobs, without requiring too much work / change on the torchtitan side.

Thanks for the RFC !

@fegin
Contributor

fegin commented May 20, 2025

For in-training validation, solution 1 looks like the straightforward and common approach, though some design work is needed to make it fit into TorchTitan's trainer architecture.

As for after-training evaluation, if we choose solution 2, "DCP to HF" is not needed. Instead, we can directly use get_state_dict to gather the full state_dict and call the HF filesystem to save it to safetensors. However, what I don't know is whether these HF frameworks require HF model definitions, or whether we can use the model definitions from TorchTitan. This makes the biggest difference, because if we have to use HF model definitions, users have to provide the state_dict conversion (e.g., "model.layer1.weight" to "model.ffn.weight"), which is not a trivial job.
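A hedged sketch of this path, using get_model_state_dict (the model-only variant of get_state_dict) from torch.distributed.checkpoint.state_dict plus safetensors; error handling and the name conversion are omitted:

```python
from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
from safetensors.torch import save_file

def save_full_safetensors(model, out_path: str, rank: int):
    # Gather a full (unsharded) state_dict; cpu_offload avoids materializing
    # the whole model on a single GPU.
    options = StateDictOptions(full_state_dict=True, cpu_offload=True)
    full_sd = get_model_state_dict(model, options=options)
    # Only one rank needs to write the file.
    if rank == 0:
        save_file(full_sd, out_path)
```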

@tianyu-l
Contributor Author

@fegin
Good point. Indeed, solution 2 has two steps: one is converting DCP weights into safetensors, the other is model parameter name conversion.

To answer your question

However, what I don't know is whether these HF frameworks require HF model definitions, or whether we can use the model definitions from TorchTitan.

I checked the code of lm_eval: without a full conversion to an HF model, one can't use the accelerate library and hence cannot do multi-GPU inference.
https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/models/huggingface.py#L100

if we have to use HF model definitions, users have to provide the state_dict conversion (e.g., "model.layer1.weight" to "model.ffn.weight"), which is not a trivial job.

For name conversion, torchtune has a _convert_weights.py associated with each model:
https://github.com/pytorch/torchtune/blob/main/torchtune/models/llama4/_convert_weights.py
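For illustration, such a conversion is typically a key-template mapping plus layer-index handling; the two entries below are hypothetical examples, not a verified mapping:

```python
import re

# Hypothetical torchtitan -> HF key templates (illustrative only).
_TT_TO_HF = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
}

def convert_to_hf(state_dict: dict) -> dict:
    converted = {}
    for key, value in state_dict.items():
        # Replace layer indices with "{}" to look up the key template.
        abstract_key = re.sub(r"\.(\d+)", ".{}", key)
        layer_nums = re.findall(r"\d+", key)
        new_key = _TT_TO_HF.get(abstract_key, key)  # fall back to the original name
        converted[new_key.format(*layer_nums)] = value
    return converted
```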

@fegin
Contributor

fegin commented May 21, 2025

If it is just a one-to-one mapping, then that's not too bad. But as you can see in the file, there is more than just name conversion. The model author will have to provide this conversion as well, but I guess there is no way to avoid it.
