Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set TF_FORCE_GPU_ALLOW_GROWTH=true by default #712

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

samos123
Copy link
Contributor

@samos123 samos123 commented Sep 24, 2024

This is needed to be able to run Fuji v2 70B on GPU without GPU memory OOMs.

@kelvin-zou can likely confirm whether this should be the default or not.

This is needed to be able to run Fuji v2 70B on GPU without GPU memory
OOMs.
Copy link
Contributor

@kelvin-zou kelvin-zou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!
Can we move it to somewhere since it is GPU specific? This launch cmd is shared across GPU and TPU.

@samos123
Copy link
Contributor Author

samos123 commented Sep 24, 2024

Hmm but we also set a lot of TPU environment variables in launch.py without any if statements. I don't think there is a better place since it needs to happen before jax is started?

Would you prefer this?

if instance_type.startswith("gpu"):
    # Prevent GPU OOM issues due to TF taking up all the GPU memory.
    # Reference: https://stackoverflow.com/a/54927279
    os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")

@samos123 samos123 requested a review from kelvin-zou September 24, 2024 22:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants