diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 74ad664d6..e6e4eee92 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,7 @@ repos: rev: 0.6.1 hooks: - id: nbstripout + exclude: ^examples/.* args: [ --keep-output, --keep-count, diff --git a/examples/sst2/sst2.ipynb b/examples/sst2/sst2.ipynb index 1a3d0e1bc..cbb7cb81e 100644 --- a/examples/sst2/sst2.ipynb +++ b/examples/sst2/sst2.ipynb @@ -10,14 +10,16 @@ "\n", "Demonstration notebook for\n", "https://github.com/google/flax/tree/main/examples/sst2" - ] + ], + "id": "29fb3c7c" }, { "cell_type": "markdown", "metadata": {}, "source": [ "**Before you start:** Select Runtime -> Change runtime type -> GPU." - ] + ], + "id": "5b526b04" }, { "cell_type": "markdown", @@ -39,14 +41,16 @@ " `train.py`.\n", "4. At any time, feel free to paste code from `train.py` into the notebook\n", " and modify it directly there!" - ] + ], + "id": "89ec78c1" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup" - ] + ], + "id": "7e4ba0dc" }, { "cell_type": "code", @@ -56,7 +60,8 @@ "source": [ "example_directory = 'examples/sst2'\n", "editor_relpaths = ('configs/default.py', 'train.py', 'models.py')" - ] + ], + "id": "ee8021b9" }, { "cell_type": "code", @@ -85,6 +90,7 @@ " os.chdir('/content')\n", " # Download Flax repo from Github.\n", " if not os.path.isdir('flaxrepo'):\n", + " pass\n", " !git clone --depth=1 https://github.com/google/flax flaxrepo\n", " # Copy example files & change directory.\n", " mount_gdrive = 'no' #@param ['yes', 'no']\n", @@ -109,7 +115,8 @@ " open(f'{example_root_path}/{relpath}', 'w').write(\n", " f'## {DISCLAIMER}\\n' + '#' * (len(DISCLAIMER) + 3) + '\\n\\n' + s)\n", " files.view(f'{example_root_path}/{relpath}')" - ] + ], + "id": "36dab290" }, { "cell_type": "code", @@ -119,7 +126,8 @@ "source": [ "# Note: In Colab, above cell changed the working directory.\n", "!pwd" - ] + ], + "id": "700d9428" }, { "cell_type": "code", @@ -129,14 +137,16 @@ 
"source": [ "# Install SST-2 dependencies.\n", "!pip install -q -r requirements.txt" - ] + ], + "id": "2fbc3e64" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports / Helpers" - ] + ], + "id": "e40c50cf" }, { "cell_type": "code", @@ -153,7 +163,8 @@ " import os\n", " os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'\n", "jax.devices()" - ] + ], + "id": "703f04fb" }, { "cell_type": "code", @@ -172,7 +183,8 @@ "\n", "# Make sure the GPU is for JAX, not for TF.\n", "tf.config.experimental.set_visible_devices([], 'GPU')" - ] + ], + "id": "32e12a97" }, { "cell_type": "code", @@ -186,20 +198,26 @@ "# Any changes you make to train.py will appear automatically.\n", "%load_ext autoreload\n", "%autoreload 2\n", - "import train\n", - "import models\n", - "import vocabulary\n", - "import input_pipeline\n", - "from configs import default as config_lib\n", + "try:\n", + " import train\n", + " import models\n", + " import vocabulary\n", + " import input_pipeline\n", + " from configs import default as config_lib\n", + "except ModuleNotFoundError as err:\n", + " # config_lib is needed on the next line, so fail here with the real cause.\n", + " raise ModuleNotFoundError(f'{err}. Run this notebook from the examples/sst2 directory.') from err\n", "config = config_lib.get_config()" - ] + ], + "id": "94ece24d" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" - ] + ], + "id": "c8a6ec00" }, { "cell_type": "code", @@ -213,14 +231,16 @@ "# If you get an error you need to install tensorflow_datasets from Github.\n", "train_dataset = input_pipeline.TextDataset(split='train')\n", "eval_dataset = input_pipeline.TextDataset(split='validation')" - ] + ], + "id": "5c615ca6" }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training" - ] + ], + "id": "7d7c55cd" }, { "cell_type": "code", @@ -231,9 +251,11 @@ "# Get a live update during training - use the \"refresh\" button!\n", "# (In Jupyter[lab] start \"tensorboard\" in the local directory instead.)\n", "if 'google.colab' in str(get_ipython()):\n", + " pass\n", " %load_ext tensorboard\n", " 
%tensorboard --logdir=." - ] + ], + "id": "df9d52ed" }, { "cell_type": "code", @@ -248,7 +270,8 @@ "start_time = time.time()\n", "optimizer = train.train_and_evaluate(config, workdir=f'./models/{model_name}')\n", "logging.info('Walltime: %f s', time.time() - start_time)" - ] + ], + "id": "f159d072" }, { "cell_type": "code", @@ -265,8 +288,10 @@ - " #@markdown Note that everbody with the link will be able to see the data.\n", + " #@markdown Note that everybody with the link will be able to see the data.\n", " upload_data = 'yes' #@param ['yes', 'no']\n", " if upload_data == 'yes':\n", + " pass\n", - " !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/mnist'" + " !tensorboard dev upload --one_shot --logdir ./models --name 'Flax examples/sst2'" - ] + ], + "id": "b8e35c72" } ], "metadata": {