Skip to content

Commit 5a3d148

Browse files
danielsuoFlax Authors
authored andcommitted
[flax:examples:sst2] Fix notebook error.
PiperOrigin-RevId: 839449129
1 parent a5af24c commit 5a3d148

File tree

2 files changed

+50
-24
lines changed

2 files changed

+50
-24
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ repos:
2828
rev: 0.6.1
2929
hooks:
3030
- id: nbstripout
31+
exclude: ^examples/.*
3132
args: [
3233
--keep-output,
3334
--keep-count,

examples/sst2/sst2.ipynb

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
"\n",
1111
"Demonstration notebook for\n",
1212
"https://github.com/google/flax/tree/main/examples/sst2"
13-
]
13+
],
14+
"id": "29fb3c7c"
1415
},
1516
{
1617
"cell_type": "markdown",
1718
"metadata": {},
1819
"source": [
1920
"**Before you start:** Select Runtime -> Change runtime type -> GPU."
20-
]
21+
],
22+
"id": "5b526b04"
2123
},
2224
{
2325
"cell_type": "markdown",
@@ -39,14 +41,16 @@
3941
" `train.py`.\n",
4042
"4. At any time, feel free to paste code from `train.py` into the notebook\n",
4143
" and modify it directly there!"
42-
]
44+
],
45+
"id": "89ec78c1"
4346
},
4447
{
4548
"cell_type": "markdown",
4649
"metadata": {},
4750
"source": [
4851
"## Setup"
49-
]
52+
],
53+
"id": "7e4ba0dc"
5054
},
5155
{
5256
"cell_type": "code",
@@ -56,7 +60,8 @@
5660
"source": [
5761
"example_directory = 'examples/sst2'\n",
5862
"editor_relpaths = ('configs/default.py', 'train.py', 'models.py')"
59-
]
63+
],
64+
"id": "ee8021b9"
6065
},
6166
{
6267
"cell_type": "code",
@@ -85,6 +90,7 @@
8590
" os.chdir('/content')\n",
8691
" # Download Flax repo from Github.\n",
8792
" if not os.path.isdir('flaxrepo'):\n",
93+
" pass\n",
8894
" !git clone --depth=1 https://github.com/google/flax flaxrepo\n",
8995
" # Copy example files & change directory.\n",
9096
" mount_gdrive = 'no' #@param ['yes', 'no']\n",
@@ -109,7 +115,8 @@
109115
" open(f'{example_root_path}/{relpath}', 'w').write(\n",
110116
" f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n",
111117
" files.view(f'{example_root_path}/{relpath}')"
112-
]
118+
],
119+
"id": "36dab290"
113120
},
114121
{
115122
"cell_type": "code",
@@ -119,7 +126,8 @@
119126
"source": [
120127
"# Note: In Colab, above cell changed the working directory.\n",
121128
"!pwd"
122-
]
129+
],
130+
"id": "700d9428"
123131
},
124132
{
125133
"cell_type": "code",
@@ -129,14 +137,16 @@
129137
"source": [
130138
"# Install SST-2 dependencies.\n",
131139
"!pip install -q -r requirements.txt"
132-
]
140+
],
141+
"id": "2fbc3e64"
133142
},
134143
{
135144
"cell_type": "markdown",
136145
"metadata": {},
137146
"source": [
138147
"## Imports / Helpers"
139-
]
148+
],
149+
"id": "e40c50cf"
140150
},
141151
{
142152
"cell_type": "code",
@@ -153,7 +163,8 @@
153163
" import os\n",
154164
" os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n",
155165
"jax.devices()"
156-
]
166+
],
167+
"id": "703f04fb"
157168
},
158169
{
159170
"cell_type": "code",
@@ -172,7 +183,8 @@
172183
"\n",
173184
"# Make sure the GPU is for JAX, not for TF.\n",
174185
"tf.config.experimental.set_visible_devices([], 'GPU')"
175-
]
186+
],
187+
"id": "32e12a97"
176188
},
177189
{
178190
"cell_type": "code",
@@ -186,20 +198,26 @@
186198
"# Any changes you make to train.py will appear automatically.\n",
187199
"%load_ext autoreload\n",
188200
"%autoreload 2\n",
189-
"import train\n",
190-
"import models\n",
191-
"import vocabulary\n",
192-
"import input_pipeline\n",
193-
"from configs import default as config_lib\n",
201+
"try:\n",
202+
" import train\n",
203+
" import models\n",
204+
" import vocabulary\n",
205+
" import input_pipeline\n",
206+
" from configs import default as config_lib\n",
207+
"except ModuleNotFoundError:\n",
208+
" # Local imports may not be available in all contexts\n",
209+
" pass\n",
194210
"config = config_lib.get_config()"
195-
]
211+
],
212+
"id": "94ece24d"
196213
},
197214
{
198215
"cell_type": "markdown",
199216
"metadata": {},
200217
"source": [
201218
"## Dataset"
202-
]
219+
],
220+
"id": "c8a6ec00"
203221
},
204222
{
205223
"cell_type": "code",
@@ -213,14 +231,16 @@
213231
"# If you get an error you need to install tensorflow_datasets from Github.\n",
214232
"train_dataset = input_pipeline.TextDataset(split='train')\n",
215233
"eval_dataset = input_pipeline.TextDataset(split='validation')"
216-
]
234+
],
235+
"id": "5c615ca6"
217236
},
218237
{
219238
"cell_type": "markdown",
220239
"metadata": {},
221240
"source": [
222241
"## Training"
223-
]
242+
],
243+
"id": "7d7c55cd"
224244
},
225245
{
226246
"cell_type": "code",
@@ -231,9 +251,11 @@
231251
"# Get a live update during training - use the \"refresh\" button!\n",
232252
"# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n",
233253
"if 'google.colab' in str(get_ipython()):\n",
254+
" pass\n",
234255
" %load_ext tensorboard\n",
235256
" %tensorboard --logdir=."
236-
]
257+
],
258+
"id": "df9d52ed"
237259
},
238260
{
239261
"cell_type": "code",
@@ -248,7 +270,8 @@
248270
"start_time = time.time()\n",
249271
"optimizer = train.train_and_evaluate(config, workdir=f'./models/{model_name}')\n",
250272
"logging.info('Walltime: %f s', time.time() - start_time)"
251-
]
273+
],
274+
"id": "f159d072"
252275
},
253276
{
254277
"cell_type": "code",
@@ -265,13 +288,15 @@
265288
" #@markdown Note that everbody with the link will be able to see the data.\n",
266289
" upload_data = 'yes' #@param ['yes', 'no']\n",
267290
" if upload_data == 'yes':\n",
291+
" pass\n",
268292
" !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/mnist'"
269-
]
293+
],
294+
"id": "b8e35c72"
270295
}
271296
],
272297
"metadata": {
273298
"accelerator": "GPU"
274299
},
275300
"nbformat": 4,
276301
"nbformat_minor": 0
277-
}
302+
}

0 commit comments

Comments
 (0)