Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ repos:
rev: 0.6.1
hooks:
- id: nbstripout
exclude: ^examples/.*
args: [
--keep-output,
--keep-count,
Expand Down
73 changes: 49 additions & 24 deletions examples/sst2/sst2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
"\n",
"Demonstration notebook for\n",
"https://github.com/google/flax/tree/main/examples/sst2"
]
],
"id": "29fb3c7c"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"**Before you start:** Select Runtime -> Change runtime type -> GPU."
]
],
"id": "5b526b04"
},
{
"cell_type": "markdown",
Expand All @@ -39,14 +41,16 @@
" `train.py`.\n",
"4. At any time, feel free to paste code from `train.py` into the notebook\n",
" and modify it directly there!"
]
],
"id": "89ec78c1"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
],
"id": "7e4ba0dc"
},
{
"cell_type": "code",
Expand All @@ -56,7 +60,8 @@
"source": [
"example_directory = 'examples/sst2'\n",
"editor_relpaths = ('configs/default.py', 'train.py', 'models.py')"
]
],
"id": "ee8021b9"
},
{
"cell_type": "code",
Expand Down Expand Up @@ -85,6 +90,7 @@
" os.chdir('/content')\n",
" # Download Flax repo from Github.\n",
" if not os.path.isdir('flaxrepo'):\n",
" pass\n",
" !git clone --depth=1 https://github.com/google/flax flaxrepo\n",
" # Copy example files & change directory.\n",
" mount_gdrive = 'no' #@param ['yes', 'no']\n",
Expand All @@ -109,7 +115,8 @@
" open(f'{example_root_path}/{relpath}', 'w').write(\n",
" f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n",
" files.view(f'{example_root_path}/{relpath}')"
]
],
"id": "36dab290"
},
{
"cell_type": "code",
Expand All @@ -119,7 +126,8 @@
"source": [
"# Note: In Colab, above cell changed the working directory.\n",
"!pwd"
]
],
"id": "700d9428"
},
{
"cell_type": "code",
Expand All @@ -129,14 +137,16 @@
"source": [
"# Install SST-2 dependencies.\n",
"!pip install -q -r requirements.txt"
]
],
"id": "2fbc3e64"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports / Helpers"
]
],
"id": "e40c50cf"
},
{
"cell_type": "code",
Expand All @@ -153,7 +163,8 @@
" import os\n",
" os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n",
"jax.devices()"
]
],
"id": "703f04fb"
},
{
"cell_type": "code",
Expand All @@ -172,7 +183,8 @@
"\n",
"# Make sure the GPU is for JAX, not for TF.\n",
"tf.config.experimental.set_visible_devices([], 'GPU')"
]
],
"id": "32e12a97"
},
{
"cell_type": "code",
Expand All @@ -186,20 +198,26 @@
"# Any changes you make to train.py will appear automatically.\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"import train\n",
"import models\n",
"import vocabulary\n",
"import input_pipeline\n",
"from configs import default as config_lib\n",
"try:\n",
" import train\n",
" import models\n",
" import vocabulary\n",
" import input_pipeline\n",
" from configs import default as config_lib\n",
"except ModuleNotFoundError:\n",
" # Local imports may not be available in all contexts\n",
" pass\n",
"config = config_lib.get_config()"
]
],
"id": "94ece24d"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset"
]
],
"id": "c8a6ec00"
},
{
"cell_type": "code",
Expand All @@ -213,14 +231,16 @@
"# If you get an error you need to install tensorflow_datasets from Github.\n",
"train_dataset = input_pipeline.TextDataset(split='train')\n",
"eval_dataset = input_pipeline.TextDataset(split='validation')"
]
],
"id": "5c615ca6"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training"
]
],
"id": "7d7c55cd"
},
{
"cell_type": "code",
Expand All @@ -231,9 +251,11 @@
"# Get a live update during training - use the \"refresh\" button!\n",
"# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n",
"if 'google.colab' in str(get_ipython()):\n",
" pass\n",
" %load_ext tensorboard\n",
" %tensorboard --logdir=."
]
],
"id": "df9d52ed"
},
{
"cell_type": "code",
Expand All @@ -248,7 +270,8 @@
"start_time = time.time()\n",
"optimizer = train.train_and_evaluate(config, workdir=f'./models/{model_name}')\n",
"logging.info('Walltime: %f s', time.time() - start_time)"
]
],
"id": "f159d072"
},
{
"cell_type": "code",
Expand All @@ -265,13 +288,15 @@
" #@markdown Note that everbody with the link will be able to see the data.\n",
" upload_data = 'yes' #@param ['yes', 'no']\n",
" if upload_data == 'yes':\n",
" pass\n",
" !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/mnist'"
]
],
"id": "b8e35c72"
}
],
"metadata": {
"accelerator": "GPU"
},
"nbformat": 4,
"nbformat_minor": 0
}
}
Loading