Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

32 b #121

Draft
wants to merge 143 commits into
base: main
Choose a base branch
from
Draft

32 b #121

wants to merge 143 commits into from

Conversation

dirkgr
Copy link
Member

@dirkgr dirkgr commented Dec 10, 2024

No description provided.

@dirkgr dirkgr requested a review from epwalsh December 10, 2024 01:06
@@ -130,7 +130,7 @@ def build(self, trainer: "Trainer") -> Optional[Callback]:
eval_batch_size = (
self.eval_batch_size
if self.eval_batch_size is not None
else trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group)
else 2 * trainer.rank_microbatch_size * get_world_size(trainer.dp_process_group)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you could instead passed an updated evaluator callback in OLMo2-32B.py:

.with_callback(
    "lm_evaluator",
    LMEvaluatorCallbackConfig(
        eval_batch_size=<whatever you want>,
        eval_dataset=NumpyDatasetConfig.from_data_mix(
            DataMix.v3_small_ppl_validation,
            name=NumpyDatasetType.padded_fsl,
            mix_base_dir=root_dir,
            sequence_length=dataset_config.effective_sequence_length,
            tokenizer=tokenizer_config,
            work_dir=get_work_dir(root_dir),
        ),
        eval_interval=1000,
    ),

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, but I think this is better. I think we can default to 2x the training batch size. It should always work.

# import flash_attn.ops.triton.cross_entropy as flash_attn_ce # type: ignore

_fused_cross_entropy_loss = triton_ce_loss.cross_entropy_loss
import flash_attn.ops.triton.cross_entropy as flash_attn_ce # type: ignore
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Our in-house triton CE loss was copied directly from the flash-attn repo, so I don't see the point of this.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I took this back out.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do I want compiling and fused loss at the same time?

"""
d_model = 5120
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a very narrow model then... are you sure about that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a clone of Qwen 32. The tradeoffs are, narrow d_model, wide FFN, GQA, lots of layers.

Comment on lines 65 to 66
fused_loss=True,
compile_loss=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the trepidation about the different loss implementations, but the way it was before was the most performant. This way will be slower and have a higher memory footprint.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we have some certainty that this will do the right thing? What happens if we take the 13B from a late checkpoint and run it?

enabled=False,
cancel_check_interval=10,
),
).with_callback(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should just add this to the common callbacks.

"lm_evaluator": LMEvaluatorCallbackConfig(

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know that we want these for everything. Default should probably be only the new, blessed ones.

@@ -590,7 +594,7 @@ def _gcs_get_bytes_range(bucket_name: str, key: str, bytes_start: int, num_bytes
)


@retriable()
@retriable(retry_condition=_gcs_is_retriable)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This general approach sort of blows up our retry time from 10 mins to 30 mins. Sort of not a fan.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But at least it looks like it works.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could always reduce the deadline/timeout

epwalsh added a commit that referenced this pull request Jan 21, 2025
This PR pulls the general important changes in from #121.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants