
Commit 97f4a49

Flax Authors committed
Merge pull request #4822 from google:add-train-script-gemma-example
PiperOrigin-RevId: 785964844
2 parents db0e302 + 2eb7baa commit 97f4a49

File tree: 17 files changed (+2766, -276 lines)


examples/gemma/README.md

Lines changed: 168 additions & 0 deletions
@@ -0,0 +1,168 @@
## Language modeling

Trains a Gemma model on the One Billion Word Benchmark (lm1b; Chelba *et al.*, 2013).

This example is based on the `lm1b_nnx` example script and similarly uses linear learning rate warmup and an inverse square root learning rate schedule.


### Requirements

* The TensorFlow Datasets `lm1b` dataset needs to be downloaded and prepared (see below).
  A sentencepiece tokenizer vocabulary will be automatically generated
  and saved on each training run.
* This example additionally depends on the `sentencepiece` and `tensorflow-text` packages.

### Downloading the LM1B Dataset

We recommend downloading and preparing the TFDS dataset beforehand. You can download and prepare the LM1B dataset using TFDS directly: `python -m tensorflow_datasets.scripts.download_and_prepare --datasets=lm1b`.
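
To confirm that the dataset was prepared correctly, you can inspect the splits TFDS registered (an optional check, not required by the example):

```bash
# Optional check: print the splits of the prepared lm1b dataset.
python3 -c "import tensorflow_datasets as tfds; print(tfds.builder('lm1b').info.splits)"
```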

#### Using Cloud Storage FUSE for TPUs

For Cloud TPUs, we recommend using a cheap standard instance and saving the prepared TFDS
data in a storage bucket, from where it can be mounted onto the TPU VM using [Cloud Storage FUSE](https://cloud.google.com/storage/docs/cloud-storage-fuse/quickstart-mount-bucket).

##### Copy the preprocessed dataset to Cloud Storage

We assume that the dataset has been downloaded and prepared, and that the `gcloud` CLI is configured. The following commands help to set up the storage bucket and copy the dataset:

```bash
# Install the gcsfuse CLI
export GCSFUSE_REPO=gcsfuse-`lsb_release -c -s`
# For example, GCSFUSE_REPO=gcsfuse-noble for Ubuntu 24.04

echo "deb https://packages.cloud.google.com/apt $GCSFUSE_REPO main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
sudo apt-get update
sudo apt-get install -y fuse gcsfuse --no-install-recommends

gcsfuse -v
# gcsfuse version 2.12.2 (Go version go1.24.0)
```

Let's find where the LM1B dataset is stored locally:
```bash
python -c "import tensorflow_datasets as tfds; b=tfds.builder('lm1b'); print(b.info.data_dir)"
# For example: /home/user/tensorflow_datasets/lm1b/1.1.0
```

Let's create a GCS bucket for the dataset and mount the bucket to a local folder. We chose the bucket name "flax-lm1b-tfdataset", but this can be changed.
```bash
gcloud storage buckets create gs://flax-lm1b-tfdataset

mkdir -p $HOME/data
gcsfuse flax-lm1b-tfdataset $HOME/data
```

Now let's copy the data to the bucket:
```bash
# Let's assume that the prepared dataset is at $HOME/tensorflow_datasets/lm1b/
cp -R $HOME/tensorflow_datasets/lm1b $HOME/data
```
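
You can check that the copy succeeded by listing the bucket contents (assuming the bucket name chosen above):

```bash
# List the copied dataset files in the bucket.
gcloud storage ls gs://flax-lm1b-tfdataset/lm1b/
```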

##### Set up the dataset on the TPU VM

We previously chose the bucket name "flax-lm1b-tfdataset" for storing the dataset; adapt this name to your situation.

```bash
# On the TPU VM
gcsfuse flax-lm1b-tfdataset $HOME/tensorflow_datasets

ls $HOME/tensorflow_datasets/lm1b/1.1.0/
```

### How to run on GPU(s)

Install JAX with CUDA support, Flax, and the example dependencies with the following commands:
```bash
pip install jax[cuda12]
# Check whether GPUs are available:
# python3 -c "import jax; print(jax.devices())"

git clone --depth=1 --branch=main https://github.com/google/flax
cd flax
pip install -e .
cd examples/gemma
pip install -r requirements.txt
```
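
You can quickly verify the editable Flax install before starting a run (an optional check):

```bash
# Should print the version of the locally installed Flax package.
python3 -c "import flax; print(flax.__version__)"
```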

Start the training:

- train a small transformer model:
```bash
python3 main.py --workdir=$HOME/logs/small_gemma_lm1b --config=configs/small.py
```

- train the Gemma3-4B model:
```bash
python3 main.py --workdir=$HOME/logs/gemma3-4b_lm1b --config=configs/gemma3_4b.py
```

To monitor the training runs with TensorBoard:
```bash
tensorboard --logdir=$HOME/logs
```


### How to run on Cloud TPUs

Set up the TPU VM and install the Flax dependencies on it as described
[here](https://cloud.google.com/tpu/docs/jax-pods) for creating pod slices, or
[here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm) for a single
v4-8 TPU.


First create a single TPU v4-8 VM and connect to it (you can find more detailed
instructions [here](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm)):

```bash
ZONE=us-central1-a
TPU_TYPE=v4-8
TPU_NAME=$USER-flax-gemma-lm1b
gcloud compute tpus tpu-vm create $TPU_NAME \
    --zone $ZONE \
    --accelerator-type $TPU_TYPE \
    --version tpu-ubuntu2204-base

gcloud compute tpus tpu-vm ssh $TPU_NAME --zone $ZONE -- \
    -L 6006:localhost:6006
```

Once connected, install JAX:

```bash
pip install "jax[tpu]>=0.2.16" \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
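
As with the GPU setup, you can check that the TPU devices are visible to JAX:

```bash
# Should print a list of TpuDevice entries.
python3 -c "import jax; print(jax.devices())"
```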
136+
137+
Then install Flax + the example dependencies:
138+
139+
```bash
140+
git clone --depth=1 --branch=main https://github.com/google/flax
141+
cd flax
142+
pip install -e .
143+
cd examples/gemma
144+
pip install -r requirements.txt
145+
```
146+
147+
In case of errors when installing example dependencies, try to upgrade existing `pip` package and downgrade `setuptools` and repeat the installation command
148+
```bash
149+
# Optionally
150+
# pip install -U pip
151+
# pip install -U "setuptools<70"
152+
# pip install -r requirements.txt
153+
```
154+
155+
And finally start the training:
156+
157+
```bash
158+
python3 main.py --workdir=$HOME/logs/gemma_lm1b_256 --config.per_device_batch_size=32
159+
```
160+
161+
Note that you might want to set `TFDS_DATA_DIR` as explained below. You probably
162+
also want to start the long-running command above in a `tmux` session and start
163+
some monitoring in a separate pane (note that we forwarded port 6006 locally
164+
above):
165+
166+
```bash
167+
tensorboard --logdir=$HOME/logs
168+
```
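
For example, a minimal sketch of the `TFDS_DATA_DIR` and `tmux` workflow mentioned above (the mount point and session name are illustrative):

```bash
# Illustrative sketch: point TFDS at the mounted dataset and run training
# inside a detachable tmux session.
export TFDS_DATA_DIR=$HOME/tensorflow_datasets
tmux new-session -s gemma-train
# Inside the tmux session:
python3 main.py --workdir=$HOME/logs/gemma_lm1b_256 --config.per_device_batch_size=32
# Detach with Ctrl-b d; reattach later with: tmux attach -t gemma-train
```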

examples/gemma/configs/default.py

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Default Hyperparameter configuration."""

import dataclasses

from train import MeshRules, TrainConfig


@dataclasses.dataclass(unsafe_hash=True)
class Config:
  # Path to load or store sentencepiece vocab file.
  vocab_path: str | None = None
  # Vocabulary size if `vocab_path` is not given.
  vocab_size: int = 35_000  # lm1b dataset vocab size: 35913 (Gemma expected vocab size: 262_144)
  # Maximum number of characters to use for training.
  max_corpus_chars: int = 10**7
  # Name of TFDS translation dataset to use.
  dataset_name: str = 'lm1b'
  # Optional name of TFDS translation dataset to use for evaluation.
  eval_dataset_name: str = 'lm1b'
  # Optional name of TFDS split to use for evaluation.
  eval_split: str = 'test'
  # Per device batch size for training.
  per_device_batch_size: int = 32
  # Per device batch size for evaluation.
  eval_per_device_batch_size: int = 32

  # Prompts for language model sampling.
  prompts: tuple[str, ...] = (
      'Paris is a the capital',
      'Flax is a',
      # From train set:
      'The shutdown was aimed at creating efficiencies as',
      # -> the plant was already operating at its maximum capacity of 3,000 tonnes of cellulose paste per day
      'A big theme of this hire is that there are parts of',
      # -> our operations that to use a pretty trite phrase , need to be taken to the next level ...

      # From test set:
      'Because of Bear Stearns , many analysts are',
      # -> raising the odds that a 2008 recession could be worse than expected
      'Next month , the Brazilian bourse',
      # -> opens a London office
  )
  # Temperature for top_p sampling.
  sampling_temperature: float = 0.0
  # Top-p sampling threshold.
  sampling_top_p: float = 0.95

  # Number of steps to take during training.
  num_train_steps: int = 500_000
  # Number of steps to take during evaluation.
  # Large enough to evaluate all samples: 306_688 / (32 * 8) = 1198
  num_eval_steps: int = 2_000
  # Number of steps to generate predictions.
  # -1 will use the whole eval dataset.
  num_predict_steps: int = 50
  # Base learning rate.
  learning_rate: float = 0.0016
  # Linear learning rate warmup.
  warmup_steps: int = 1000
  # Cross entropy loss label smoothing.
  label_smoothing: float = 0.0
  # Decay factor for AdamW style weight decay.
  weight_decay: float = 0.1
  # Maximum length cutoff for training examples.
  max_target_length: int = 128
  # Maximum length cutoff for eval examples.
  max_eval_target_length: int = 512

  # Gemma transformer name.
  # Possible values defined in transformer.TransformerConfig:
  # (gemma_2b, gemma_7b, gemma2_2b, gemma2_9b, gemma2_27b, gemma3_1b, gemma3_4b, ...)
  transformer_name: str | None = "gemma3_1b"
  # Alternatively, define the model using a dict of parameters.
  transformer_params: dict | None = None

  # Whether to save model checkpoints.
  save_checkpoints: bool = True
  # Whether to restore from existing model checkpoints.
  restore_checkpoints: bool = True
  # Save a checkpoint every this many steps.
  checkpoint_every_steps: int = 10_000
  # Frequency of eval during training, e.g. every 1_000 steps.
  eval_every_steps: int = 5_000
  # Use bfloat16 mixed precision training instead of float32.
  use_bfloat16: bool = True
  # Integer for PRNG random seed.
  seed: int = 0

  # Parallelism
  mesh_axes: tuple[str, ...] = ('data', 'fsdp', 'tensor')
  axis_rules: MeshRules = MeshRules(
      embed='fsdp',
      mlp='tensor',
      kv='tensor',
      vocab='tensor',
  )
  data_sharding: tuple[str, ...] = ('data', 'fsdp')

  # One axis for each parallelism type may hold a placeholder (-1)
  # value to auto-shard based on available slices and devices.
  # By default, product of the DCN axes should equal number of slices
  # and product of the ICI axes should equal number of devices per slice.
  # ICI (Inter-Chip Interconnection): A high-speed connection between
  # sets of TPU chips, which form the TPU network.
  # DCN (Data Center Network): A connection between the TPU networks;
  # not as fast as ICI.
  # ICI has around 100x the bandwidth of DCN, but it is not a general
  # purpose connection, which is why DCN is necessary for scaling to
  # extremely large ML models.
  dcn_data_parallelism: int = -1
  dcn_fsdp_parallelism: int = 1
  dcn_tensor_parallelism: int = 1
  ici_data_parallelism: int = 1
  ici_fsdp_parallelism: int = -1
  ici_tensor_parallelism: int = 1


def get_config() -> TrainConfig:
  """Get the default hyperparameter configuration."""
  config = Config()
  return TrainConfig(**dataclasses.asdict(config))
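
For reference, individual fields of this config can be overridden from the command line, in the same way the README overrides `per_device_batch_size`. A sketch, assuming the `--config.<field>` mechanism shown in the README also applies to the other scalar fields:

```bash
# Sketch (assumption): launch training with the default config and override
# selected Config fields using the --config.<field> syntax from the README.
python3 main.py \
  --workdir=$HOME/logs/gemma3-1b_lm1b \
  --config=configs/default.py \
  --config.per_device_batch_size=16 \
  --config.num_train_steps=100000
```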
