non parallelized basic validator implementation [WIP] #1362

Open · wants to merge 5 commits into main
Conversation

wesleytruong (Contributor)

The purpose of this PR is to create a basic, non-parallelized validator implementation and to get feedback on code structure and cleanliness.
Changes:

  • Created a validation section in job_config
  • Created a builder function for the validator in train_spec
  • Created a builder function for the validation dataset in hf_dataset.py
  • Created a Validator class (see the minimal sketch after this list)
    • The Validator class initializes a build_validation_hf_loader but leaves this dataloader function unexposed to the train_spec
  • Integrated the validation call into the training loop
  • Created one simple integration test with no parallelization and NGPU=1
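
For orientation, here is a minimal sketch of the structure described above. All names and signatures are illustrative (the freq/steps config fields and the "validation/loss" key follow suggestions made later in this review), not the actual torchtitan implementation.

```python
import torch
import torch.nn as nn

# Minimal sketch of the non-parallelized validator described in this PR.
# Names and signatures are illustrative, not the final torchtitan API.
class Validator:
    def __init__(self, job_config, dataloader, loss_fn, model: nn.Module):
        self.job_config = job_config
        self.validation_dataloader = dataloader
        self.loss_fn = loss_fn
        self.model = model

    def should_validate(self, step: int) -> bool:
        # run validation every `freq` training steps
        return step % self.job_config.validation.freq == 0

    @torch.no_grad()
    def validate(self) -> dict[str, float]:
        self.model.eval()
        total_loss, num_steps = 0.0, 0
        for input_dict, labels in self.validation_dataloader:
            preds = self.model(input_dict["input"])
            total_loss += self.loss_fn(preds, labels).item()
            num_steps += 1
            # steps == -1 means consume the whole validation dataset
            if 0 < self.job_config.validation.steps <= num_steps:
                break
        self.model.train()
        return {"validation/loss": total_loss / max(num_steps, 1)}
```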

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on Jul 2, 2025
@tianyu-l (Contributor) left a comment

First pass looks really good!
I left many detailed comments; please see if they make sense.

):
self.job_config = job_config
self.loss_fn = loss_fn
self.model = model

I think we should pass model (model_parts) as an arg to validate, because it's changing
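
A small sketch of what this suggestion could look like; the signature is illustrative, not the final API:

```python
import torch.nn as nn

class Validator:
    # Sketch only: pass the (possibly sharded) model parts on each call, because
    # the weights change during training and should not be captured at
    # construction time.
    def validate(self, model_parts: list[nn.Module]) -> dict[str, float]:
        ...
```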

Comment on lines 42 to 47
job_config: JobConfig,
loss_fn: LossFunction,
model: nn.Module,
dp_world_size: int,
dp_rank: int,
tokenizer: Tokenizer,

Let's make the order as close as possible to how you use them below in build_hf_validation_dataloader.


seq_len: int = 2048
"""Sequence length for validation"""


Set up a steps config controlling how many iterations we run; default to -1, which means consuming all the data in the validation dataset.

Comment on lines 53 to 56
# path="tests/assets/c4_test",
# loader=lambda path: load_dataset(path, split="validation"),
# text_processor=_process_c4_text,
# ),

we should use path="allenai/c4", and loader=lambda path: load_dataset(path, name="en", split="validation"),

@@ -319,6 +321,23 @@ def __init__(self, job_config: JobConfig):
device_type,
)

# Build validator if validation is configured
self.validator = None
if (

if job_config.validation.enabled:
  assert self.train_spec.build_validator_fn is not None
  # build validator ...

@@ -319,6 +321,23 @@ def __init__(self, job_config: JobConfig):
device_type,
)

# Build validator if validation is configured
self.validator = None

I don't think you need this line, since it's already defined as an instance variable.

for k, v in input_dict.items():
if isinstance(v, torch.Tensor):
input_dict[k] = v.to(device_type)
if isinstance(labels, torch.Tensor):

why do we need this if?

Comment on lines 70 to 71
for batch_data, targets in self.validation_dataloader:
input_dict, labels = batch_data, targets

Suggested change
for batch_data, targets in self.validation_dataloader:
input_dict, labels = batch_data, targets
for input_dict, labels in self.validation_dataloader:

logger.warning("No validation batches processed")

# Set model back to train mode
self.model.train()

let's put this as the last line of this method

@wesleytruong (Contributor, Author)

I've cleaned up the code according to your comments and added support for the validation frequency and steps. I also left streaming=True in the c4_validation dataset since otherwise it downloads the entire training dataset too. @tianyu-l

seq_len: int = 2048
"""Sequence length for validation"""

val_freq: int = 1

no need to have the val_ prefix as it's not ambiguous under Validation

Suggested change
val_freq: int = 1
freq: int = 1


maybe default to 10

"""Frequency of validation"""

val_steps: int = -1
"""Number of validation steps, -1 means all steps"""

Suggested change
"""Number of validation steps, -1 means all steps"""
"""Number of validation steps, -1 means consuming all the data in the validation dataset"""

dp_rank: int,
tokenizer: Tokenizer,
job_config: JobConfig,
infinite: bool = True,

I think we can remove this arg -- I don't think anyone wants to do multiple loops over the validation dataset

seq_len=seq_len,
dp_rank=dp_rank,
dp_world_size=dp_world_size,
infinite=infinite,

So you can always set it to False here.

@@ -54,6 +54,7 @@ tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1
disable_loss_parallel = true

revert this change?

@@ -463,6 +477,12 @@ def train_step(
else:
global_avg_loss = global_max_loss = loss.detach().item()

# Run validation if validator is available

As this is not part of the training step, let's move it outside train_step and into train, before self.checkpointer.save(...).
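
A rough sketch of the suggested placement; attribute and method names are illustrative and mirror the trainer discussed in this PR, not torchtitan's actual train loop:

```python
# Illustrative only: validation runs between train_step and checkpointer.save,
# not inside train_step. should_validate/validate follow the names used in this PR.
class TrainerSketch:
    def train(self):
        while self.step < self.job_config.training.steps:
            self.step += 1
            self.train_step()

            if self.validator is not None and self.validator.should_validate(self.step):
                self.validator.validate(self.model_parts)

            self.checkpointer.save(self.step)
```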

"--validation.dataset c4_test",
],
],
"Validation test no parallelism",

Technically this is not without parallelism -- you are doing data parallelism for validation; however, you are not doing an all-reduce on the loss, so the loss you print out would be different on each DP rank. Let's do that in this PR, following the existing loss-handling code around the model forward:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L451-L464

For that you'll need to pass in parallel_dims, world_mesh, and ft_manager when constructing Validator.

I think the code will then support Tensor Parallel and Context Parallel, but not Pipeline Parallel yet, which we can do in a follow-up PR.
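
A minimal sketch of the loss all-reduce, using plain torch.distributed rather than torchtitan's own utilities (the linked train.py code is the authoritative reference; dp_group here is an assumed handle to the data-parallel process group):

```python
import torch
import torch.distributed as dist

def dist_mean_loss(local_loss: torch.Tensor, dp_group) -> float:
    """Average a per-rank validation loss across the data-parallel group.

    Sketch only: torchtitan uses its own dist utilities for this; here we just
    show the underlying all-reduce.
    """
    if dp_group is None or not dist.is_initialized():
        return local_loss.item()
    loss = local_loss.detach().clone()
    dist.all_reduce(loss, op=dist.ReduceOp.SUM, group=dp_group)
    return (loss / dist.get_world_size(dp_group)).item()
```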

model_parts: list[nn.Module],
) -> dict[str, float]:
# Set model to eval mode
model = model_parts[0]

Add a TODO here noting that we only support data parallelism for now.


Is there a reason to not support all parallelisms besides PP here?

num_val_steps = 0

with torch.no_grad():
try:

I believe you don't need this try/except, because StopIteration is handled automatically and safely by the for loop.

@runame (Contributor) left a comment


Thanks for implementing this; it will be very useful!

You can take a look at these changes for some inspiration for addressing some of my comments.

if self.job_config.validation.enabled and self.validator.should_validate(
self.step
):
validation_metrics = self.validator.validate(self.model_parts)

The validation metrics should be logged by self.metrics_processor.log() (to the terminal output and Tensorboard/wandb).

# Build validator if validation is configured
if job_config.validation.enabled:
assert self.train_spec.build_validator_fn is not None


Can you raise an error here if parallel_dims.pp_enabled?

@@ -49,6 +49,13 @@ class DatasetConfig:
loader=lambda path: load_dataset(path, split="train"),
text_processor=_process_c4_text,
),
"c4_validation": DatasetConfig(
path="allenai/c4",
loader=lambda path: load_dataset(

Nit: you can reuse _load_c4_dataset together with functools.partial here by adding split as an argument to _load_c4_dataset.
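
A sketch of that suggestion, assuming _load_c4_dataset currently hard-codes split="train"; exact signatures in hf_dataset.py may differ:

```python
from functools import partial
from datasets import load_dataset

# Sketch: add `split` as a parameter so the same loader serves both entries.
def _load_c4_dataset(dataset_path: str, split: str):
    return load_dataset(dataset_path, name="en", split=split, streaming=True)

# "c4":            DatasetConfig(..., loader=partial(_load_c4_dataset, split="train"))
# "c4_validation": DatasetConfig(..., loader=partial(_load_c4_dataset, split="validation"))
```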

@@ -193,3 +200,34 @@ def build_hf_dataloader(
dp_world_size=dp_world_size,
batch_size=batch_size,
)


def build_hf_validation_dataloader(

I don't think adding a new function for this is necessary; I would prefer replacing the job_config argument with dataset_name, dataset_path, batch_size, and seq_len. The reasoning is that for validation the function is also just returning a data loader based on an HF dataset; only the underlying dataset will be different.
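
A hypothetical version of that signature; parameter names follow the comment above, and the real build_hf_dataloader may take additional arguments:

```python
from typing import Any

# Hypothetical sketch: one dataloader builder that both training and validation
# can call, differing only in the dataset name/path and batch geometry.
def build_hf_dataloader(
    dataset_name: str,
    dataset_path: str | None,
    tokenizer: Any,
    batch_size: int,
    seq_len: int,
    dp_rank: int,
    dp_world_size: int,
    infinite: bool = True,
):
    ...
```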

@@ -657,6 +657,30 @@ class Experimental:
"""


@dataclass
class Validation:
enabled: bool = False

You could remove this field and modify val_freq to offer an option for disabling validation, e.g., val_freq: int | None = 10, where validation is disabled if val_freq=None.

# Compute average loss
if num_batches > 0:
average_loss = total_loss / num_batches
else:

I think this code path should never be used; you could guarantee this (ignoring the case of an empty dataloader) by adding a __post_init__ to the Validation dataclass that verifies that all values are valid, e.g., val_steps > 0.
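
A sketch combining this with the earlier freq: int | None suggestion; field names follow the review discussion and may not match the final config class:

```python
from dataclasses import dataclass

@dataclass
class Validation:
    freq: int | None = 10   # None disables validation entirely
    steps: int = -1         # -1 means consume the whole validation dataset

    def __post_init__(self):
        if self.freq is not None and self.freq <= 0:
            raise ValueError(f"validation.freq must be positive, got {self.freq}")
        if self.steps == 0 or self.steps < -1:
            raise ValueError(f"validation.steps must be -1 or positive, got {self.steps}")
```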

# Set model back to train mode
model.train()

return {"validation_loss": average_loss}

The average_loss is the local loss for each rank, but should still be all-reduced across ranks.

# Set model back to train mode
model.train()

return {"validation_loss": average_loss}

Could you change this to "validation/loss"? This is important for how wandb represents the metrics and allows you to add more metrics to the same section via "validation/<your-new-metric>" later on.

total_loss += loss.item()
num_batches += 1

num_val_steps += 1

Is there a reason you use separate counters for num_batches and num_val_steps? Also, you could use this instead:

for step, (input_dict, labels) in enumerate(self.validation_dataloader):

Here, step replaces num_batches and num_val_steps. You would also have to change num_val_steps >= self.job_config.validation.val_steps to step > self.job_config.validation.val_steps above.
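
Putting this together with the earlier note that no try/except is needed around the loop, a sketch of the resulting loop (names are illustrative):

```python
# Sketch: a single enumerate counter replaces num_batches/num_val_steps, and the
# for loop already terminates cleanly when the dataloader is exhausted.
def run_validation_loop(validation_dataloader, max_steps: int):
    total_loss, steps_run = 0.0, 0
    for step, (input_dict, labels) in enumerate(validation_dataloader, start=1):
        ...  # forward pass, loss computation, accumulation into total_loss
        steps_run = step
        if max_steps != -1 and step >= max_steps:
            break
    return total_loss / max(steps_run, 1)
```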

device_type = utils.device_type
num_val_steps = 0

with torch.no_grad():

Nit: you can also use this as a decorator instead, so you don't have to indent your code as much.

@torch.no_grad()
def validate(
