diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py
old mode 100755
new mode 100644
index aa23b638f0..307507ad1d
--- a/docs/examples/te_llama/te_llama.py
+++ b/docs/examples/te_llama/te_llama.py
@@ -100,13 +100,21 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k
subfolder = ""
variant = None
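+        # Prefer a sharded safetensors index if present; otherwise fall back to the sharded PyTorch (.bin) index below.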
if os.path.isfile(
- os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+ os.path.join(pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant))
):
- # Load from a sharded PyTorch checkpoint
- archive_file = os.path.join(
- pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
- )
- is_sharded = True
+            # Load from a sharded safetensors checkpoint
+ archive_file = os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant)
+ )
+ is_sharded = True
+ elif os.path.isfile(
+ os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
+ ):
+ # Load from a sharded PyTorch checkpoint
+ archive_file = os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+ )
+ is_sharded = True
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
index cc77b484f9..57c1bf6601 100755
--- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
+++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -2,23 +2,23 @@
"cells": [
{
"cell_type": "markdown",
- "id": "2cac9d39",
+ "id": "6a5b2993",
"metadata": {},
"source": [
- "# Accelerating a Hugging Face Llama 2 model with Transformer Engine\n",
+    "# Accelerating Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n",
     "\n",
     "<div class=\"alert alert-info\">\n",
"\n",
"Goal\n",
"\n",
- "This tutorial showcases how to accelerate finetuning a full Llama 2 model from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-hf) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
+    "This tutorial showcases how to accelerate finetuning a full [Llama 2](https://huggingface.co/meta-llama/Llama-2-7b-hf) or [Llama 3](https://huggingface.co/meta-llama/Meta-Llama-3-8B) model from Hugging Face by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n",
     "\n",
     "</div>\n"
]
},
{
"cell_type": "markdown",
- "id": "401f7fb1",
+ "id": "331f476a",
"metadata": {},
"source": [
"## Dependencies for this tutorial\n",
@@ -26,16 +26,28 @@
"Following files and media are necessary to effectively run this tutorial:\n",
"\n",
"1. `te_llama.py`\n",
- " - This file contains the code to load a Hugging Face Llama 2 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
+ " - This file contains the code to load a Hugging Face Llama 2 or Llama 3 checkpoint in Transformer Engine's `TransformerLayer` instead of Hugging Face's `LlamaDecoderLayer`. This is used in the following two sections of the tutorial - \"Improvement 1\" and \"Improvement 2\".\n",
"2. `utils.py`\n",
" - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training and other miscellaneous tasks like restarting the jupyter notebook from within the cell. \n",
"3. `media/`\n",
- " - This directory contains the images used in the following tutorial."
+ " - This directory contains the images used in the following tutorial.\n",
+ "\n",
+ "These packages are necessary to run this tutorial:\n",
+ "`pytorch`, `transformer_engine`, `accelerate`, `transformers`, `peft`, `datasets`.\n",
+ "\n",
+ "\n",
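+    "As a quick optional sanity check (not part of the tutorial files), one way to verify that these packages are importable from the current environment:\n",
+    "\n",
+    "```python\n",
+    "import importlib.util\n",
+    "\n",
+    "# Note: the `pytorch` package is imported as `torch`; the other names match their import names.\n",
+    "for pkg in (\"torch\", \"transformer_engine\", \"accelerate\", \"transformers\", \"peft\", \"datasets\"):\n",
+    "    print(pkg, \"found\" if importlib.util.find_spec(pkg) else \"MISSING\")\n",
+    "```\n",
+    "\n",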
+    "<div class=\"alert alert-info\">\n",
+ "\n",
+ "Note on running the tutorial with Llama 3 weights\n",
+ "\n",
+    "This tutorial shows the cell outputs when run with Llama 2 7B weights. It can be run with Llama 3 8B weights simply by pointing it at a directory with those weights (in Hugging Face format) instead of the Llama 2 7B weights. The two architectures are almost identical, the main difference being the model size (the smallest Llama 3 model has 8B parameters, whereas the smallest Llama 2 model has 7B), which is why this tutorial works for both.\n",
+ "\n",
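+    "For example, the only change needed is the weights path. This sketch assumes the `hyperparams` object and its `model_name` field from `utils.py` that are used later in this tutorial:\n",
+    "\n",
+    "```python\n",
+    "# Assumption: `hyperparams` is provided by utils.py; point `model_name` at a directory\n",
+    "# containing Llama 3 8B weights in Hugging Face format instead of the Llama 2 7B directory.\n",
+    "from utils import hyperparams\n",
+    "\n",
+    "hyperparams.model_name = \"/path/to/Meta-Llama-3-8B\"  # placeholder path\n",
+    "```\n",
+    "\n",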
+    "</div>\n"
]
},
{
"cell_type": "markdown",
- "id": "33bdb5fe",
+ "id": "44abae4f",
"metadata": {},
"source": [
"## Table of contents\n",
@@ -53,7 +65,7 @@
},
{
"cell_type": "markdown",
- "id": "7645f176",
+ "id": "e37e2cc1",
"metadata": {},
"source": [
"## From \"Transformer\" to \"Llama\" \n",
@@ -67,10 +79,13 @@
"\n",
"- 2017: [\"Attention Is All You Need\"](https://arxiv.org/abs/1706.03762) paper introduced pioneering \"Transformer\" architecture and changed the NLP field forever.\n",
"- 2018-2020: Emergence of GPT model series that showed causal decoder architectures are great fit for pretraining, few-shot and zero-shot learning.\n",
- "- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases. \n",
- "- One of the latest in this line of pretrained models which is also open source is Meta's [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n",
- " - These models range from 7B to 65B parameters.\n",
+ "- Fast forward to 2023-2024: Following GPT-3/GPT-4 success stories, researchers and companies raced to produce the next best pretrained model that could further be finetuned for application-specific use-cases.\n",
+    "- July 2023: Meta releases [Llama 2](https://llama.meta.com/llama2) models (Large Language Model Meta AI). \n",
+ " - These models range from 7B to 70B parameters.\n",
" - LLaMA 2 was pretrained on 2 trillion tokens.\n",
+ "- April 2024: Meta releases [Llama 3](https://llama.meta.com/llama3) models.\n",
+ " - These models range from 8B to 70B parameters.\n",
+ " - LLaMA 3 was pretrained on 15 trillion tokens.\n",
"\n",
"For more information on Llama 2 consider reading the [Huggingface tutorial](https://huggingface.co/blog/llama2). As a quick summary, here are some of the important differences b/w the conventional transformer decoder architecture vs Llama 2 architecture:\n",
"\n",
@@ -78,9 +93,16 @@
"2. RMSNorm in place of the LayerNorm\n",
"3. SwiGLU activation function\n",
"4. RoPE as positional embeddings \n",
- "5. Grouped Query Attention\n",
+ "5. Grouped Query Attention for the 70B model\n",
"6. Trained on 4K context length\n",
"\n",
+ "Hugging Face also released a [tutorial about Llama 3](https://huggingface.co/blog/llama3). The key points are:\n",
+ "\n",
+    "1. A bigger tokenizer vocabulary - 128256 tokens vs. 32K in Llama 2.\n",
+    "2. Grouped Query Attention is also used by the smaller 8B model (see the short sketch below).\n",
+    "3. The context length increased to 8K for all models.\n",
+    "4. Llama 3 was trained on 8x more data than Llama 2.\n",
+ "\n",
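+    "As a rough illustration of point 2 (not part of the tutorial code), grouped query attention shares each key/value head across several query heads, which shrinks the KV cache. A minimal PyTorch sketch, with head counts assumed here purely for illustration:\n",
+    "\n",
+    "```python\n",
+    "import torch\n",
+    "\n",
+    "def repeat_kv(kv: torch.Tensor, n_rep: int) -> torch.Tensor:\n",
+    "    # Each key/value head is shared by `n_rep` query heads:\n",
+    "    # [batch, n_kv_heads, seq, head_dim] -> [batch, n_kv_heads * n_rep, seq, head_dim]\n",
+    "    return kv.repeat_interleave(n_rep, dim=1)\n",
+    "\n",
+    "n_heads, n_kv_heads = 32, 8                 # assumed for illustration: 4 query heads per KV head\n",
+    "keys = torch.randn(1, n_kv_heads, 16, 128)  # toy key tensor\n",
+    "print(keys.shape, \"->\", repeat_kv(keys, n_heads // n_kv_heads).shape)\n",
+    "```\n",
+    "\n",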
"