1010 " \n " ,
1111 " Demonstration notebook for\n " ,
1212 " https://github.com/google/flax/tree/main/examples/sst2"
13- ]
13+ ],
14+ "id" : " 29fb3c7c"
1415 },
1516 {
1617 "cell_type" : " markdown" ,
1718 "metadata" : {},
1819 "source" : [
1920 " **Before you start:** Select Runtime -> Change runtime type -> GPU."
20- ]
21+ ],
22+ "id" : " 5b526b04"
2123 },
2224 {
2325 "cell_type" : " markdown" ,
3941 " `train.py`.\n " ,
4042 " 4. At any time, feel free to paste code from `train.py` into the notebook\n " ,
4143 " and modify it directly there!"
42- ]
44+ ],
45+ "id" : " 89ec78c1"
4346 },
4447 {
4548 "cell_type" : " markdown" ,
4649 "metadata" : {},
4750 "source" : [
4851 " ## Setup"
49- ]
52+ ],
53+ "id" : " 7e4ba0dc"
5054 },
5155 {
5256 "cell_type" : " code" ,
5660 "source" : [
5761 " example_directory = 'examples/sst2'\n " ,
5862 " editor_relpaths = ('configs/default.py', 'train.py', 'models.py')"
59- ]
63+ ],
64+ "id" : " ee8021b9"
6065 },
6166 {
6267 "cell_type" : " code" ,
8590 " os.chdir('/content')\n " ,
8691 " # Download Flax repo from Github.\n " ,
8792 " if not os.path.isdir('flaxrepo'):\n " ,
93+ " pass\n " ,
8894 " !git clone --depth=1 https://github.com/google/flax flaxrepo\n " ,
8995 " # Copy example files & change directory.\n " ,
9096 " mount_gdrive = 'no' #@param ['yes', 'no']\n " ,
109115 " open(f'{example_root_path}/{relpath}', 'w').write(\n " ,
110116 " f'## {DISCLAIMER}\\ n' + '#' * (len(DISCLAIMER) + 3) + '\\ n\\ n' + s)\n " ,
111117 " files.view(f'{example_root_path}/{relpath}')"
112- ]
118+ ],
119+ "id" : " 36dab290"
113120 },
114121 {
115122 "cell_type" : " code" ,
119126 "source" : [
120127 " # Note: In Colab, above cell changed the working directory.\n " ,
121128 " !pwd"
122- ]
129+ ],
130+ "id" : " 700d9428"
123131 },
124132 {
125133 "cell_type" : " code" ,
129137 "source" : [
130138 " # Install SST-2 dependencies.\n " ,
131139 " !pip install -q -r requirements.txt"
132- ]
140+ ],
141+ "id" : " 2fbc3e64"
133142 },
134143 {
135144 "cell_type" : " markdown" ,
136145 "metadata" : {},
137146 "source" : [
138147 " ## Imports / Helpers"
139- ]
148+ ],
149+ "id" : " e40c50cf"
140150 },
141151 {
142152 "cell_type" : " code" ,
153163 " import os\n " ,
154164 " os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n " ,
155165 " jax.devices()"
156- ]
166+ ],
167+ "id" : " 703f04fb"
157168 },
158169 {
159170 "cell_type" : " code" ,
172183 " \n " ,
173184 " # Make sure the GPU is for JAX, not for TF.\n " ,
174185 " tf.config.experimental.set_visible_devices([], 'GPU')"
175- ]
186+ ],
187+ "id" : " 32e12a97"
176188 },
177189 {
178190 "cell_type" : " code" ,
186198 " # Any changes you make to train.py will appear automatically.\n " ,
187199 " %load_ext autoreload\n " ,
188200 " %autoreload 2\n " ,
189- " import train\n " ,
190- " import models\n " ,
191- " import vocabulary\n " ,
192- " import input_pipeline\n " ,
193- " from configs import default as config_lib\n " ,
201+ " try:\n " ,
202+ " import train\n " ,
203+ " import models\n " ,
204+ " import vocabulary\n " ,
205+ " import input_pipeline\n " ,
206+ " from configs import default as config_lib\n " ,
207+ " except ModuleNotFoundError:\n " ,
208+ " # Local imports may not be available in all contexts\n " ,
209+ " pass\n " ,
194210 " config = config_lib.get_config()"
195- ]
211+ ],
212+ "id" : " 94ece24d"
196213 },
197214 {
198215 "cell_type" : " markdown" ,
199216 "metadata" : {},
200217 "source" : [
201218 " ## Dataset"
202- ]
219+ ],
220+ "id" : " c8a6ec00"
203221 },
204222 {
205223 "cell_type" : " code" ,
213231 " # If you get an error you need to install tensorflow_datasets from Github.\n " ,
214232 " train_dataset = input_pipeline.TextDataset(split='train')\n " ,
215233 " eval_dataset = input_pipeline.TextDataset(split='validation')"
216- ]
234+ ],
235+ "id" : " 5c615ca6"
217236 },
218237 {
219238 "cell_type" : " markdown" ,
220239 "metadata" : {},
221240 "source" : [
222241 " ## Training"
223- ]
242+ ],
243+ "id" : " 7d7c55cd"
224244 },
225245 {
226246 "cell_type" : " code" ,
231251 " # Get a live update during training - use the \" refresh\" button!\n " ,
232252 " # (In Jupyter[lab] start \" tensorboard\" in the local directory instead.)\n " ,
233253 " if 'google.colab' in str(get_ipython()):\n " ,
254+ " pass\n " ,
234255 " %load_ext tensorboard\n " ,
235256 " %tensorboard --logdir=."
236- ]
257+ ],
258+ "id" : " df9d52ed"
237259 },
238260 {
239261 "cell_type" : " code" ,
248270 " start_time = time.time()\n " ,
249271 " optimizer = train.train_and_evaluate(config, workdir=f'./models/{model_name}')\n " ,
250272 " logging.info('Walltime: %f s', time.time() - start_time)"
251- ]
273+ ],
274+ "id" : " f159d072"
252275 },
253276 {
254277 "cell_type" : " code" ,
265288 " #@markdown Note that everbody with the link will be able to see the data.\n " ,
266289 " upload_data = 'yes' #@param ['yes', 'no']\n " ,
267290 " if upload_data == 'yes':\n " ,
291+ " pass\n " ,
268292 " !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/mnist'"
269- ]
293+ ],
294+ "id" : " b8e35c72"
270295 }
271296 ],
272297 "metadata" : {
273298 "accelerator" : " GPU"
274299 },
275300 "nbformat" : 4 ,
276301 "nbformat_minor" : 0
277- }
302+ }
0 commit comments