As a toy project, this repository provides implementations of first-order Differentiable Architecture Search (DARTS) on the Fashion-MNIST dataset in three different frameworks.
I finished this project on my Xiaomi laptop with a 6 GB Nvidia RTX 2060 GPU. For more information about the development environment, please refer to this page.
All three flavors (frameworks) share similar interfaces and storylines, which are mostly self-explanatory.
To get started, just run `Main.ipynb` and follow its lead.
If you'd prefer not to play with it in Jupyter Notebook, don't forget to change `tqdm.notebook` to `tqdm`.
The dataset (`.gz`) has been put under `darts_pt_pp_tf/data/fmnist/`.
The way to tune the configuration is to modify `config.yml`.
- We prepare the primitive operations (op).
- With op at hand, we are free to construct Cells: a Normal Cell with `stride=1`, a Reduction Cell with `stride=2`.
- Define the Kernel network, which is a stack of Cells.
- Encapsulate the Kernel network in the Shell network, which has two more sets of trainable parameters --- the `alphas`.
- Searching process:
  - Update the trainable parameters of the Kernel.
  - Update the `alphas` in the Shell.
  - Save the best-searched Cells.
- Training process:
  - Reconstruct the Kernel network with the searched Cells.
  - Training and validation.
  - Save the best model.
- Prediction process:
  - Load the best model.
  - Prediction.
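The heart of the searching process above is DARTS's continuous relaxation: each edge of a Cell computes a softmax-weighted sum of candidate operations, with the weights given by the `alphas`. A minimal NumPy sketch of that idea (all names here, e.g. `mixed_op`, are illustrative and not taken from this repo):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Toy candidate "operations" standing in for skip-connect, conv, none, etc.
ops = [
    lambda x: x,                   # identity (skip-connect)
    lambda x: 2.0 * x,             # stand-in for a conv op
    lambda x: np.zeros_like(x),    # the "none" op
]

def mixed_op(x, alphas):
    """DARTS continuous relaxation: softmax(alphas)-weighted sum of candidates."""
    w = softmax(alphas)
    return sum(wi * op(x) for wi, op in zip(w, ops))

alphas = np.zeros(3)   # architecture parameters, updated in the Shell
x = np.ones(4)
y = mixed_op(x, alphas)  # before training, every op contributes equally
```

After searching, the op with the largest `alpha` on each edge is kept, which is what "save the best-searched Cells" amounts to.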
The parameter updates and the training and validation processes all follow a procedure like:
- Get x, y from the data pipeline.
- Get the loss value (forward pass).
- Backpropagation.
- Gradient descent on certain parameters.
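The four steps above can be sketched end to end with NumPy and a hand-derived gradient for a scalar linear model (illustrative only; the actual code relies on each framework's autograd):

```python
import numpy as np

rng = np.random.default_rng(0)
w = 0.0      # the parameter being trained
lr = 0.1     # learning rate

for step in range(100):
    # 1. Get x, y from the data pipeline (toy target: y = 3x).
    x = rng.uniform(1.0, 2.0)
    y = 3.0 * x
    # 2. Get the loss value (forward).
    loss = (w * x - y) ** 2
    # 3. Backpropagation (gradient of the squared error, derived by hand).
    grad = 2.0 * (w * x - y) * x
    # 4. Gradient descent on the parameter.
    w -= lr * grad

# w converges toward 3.0
```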
- The `affine` argument in Batch Normalization is set to `False` for the Searching process and `True` for the Training process.
- For `ReduceLROnPlateau`: `patience=10, factor=0.5`. We didn't put these arguments in `config.yml` for simplicity.
- Don't iterate over the variable returned by `fluid.layers.create_parameter`; the iteration will not stop at the end but raise an out-of-boundary error.
- For tf-2.2.0 we need `tf.config.experimental.set_memory_growth(gpu_check[0], True)`; otherwise there would be an OOM problem on my laptop.
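The `affine` switch in the first note can be pictured as follows: with `affine=False`, Batch Normalization only standardizes its input, while `affine=True` adds a learnable scale `gamma` and shift `beta`. A NumPy sketch (variable names are mine, not the frameworks'):

```python
import numpy as np

def batch_norm(x, gamma=None, beta=None, eps=1e-5):
    """Standardize each feature over the batch; gamma/beta exist only if affine."""
    x_hat = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)
    if gamma is None:                # affine=False: no learnable scale/shift
        return x_hat
    return gamma * x_hat + beta      # affine=True

x = np.array([[1.0, 10.0],
              [3.0, 30.0]])
out = batch_norm(x)                              # zero mean per feature
out_affine = batch_norm(x, gamma=2.0, beta=1.0)  # scaled and shifted
```

Dropping the affine parameters during searching keeps the supernet smaller and avoids learning scale/shift for Cells that will be discarded anyway.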
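The plateau schedule from the second note behaves roughly like this simplified re-implementation (the real schedulers also support cooldown, thresholds, etc., which are omitted here):

```python
class PlateauLR:
    """Multiply the LR by `factor` when the loss hasn't improved for `patience` epochs."""

    def __init__(self, lr, patience=10, factor=0.5):
        self.lr, self.patience, self.factor = lr, patience, factor
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, loss):
        if loss < self.best:          # improvement: reset the counter
            self.best = loss
            self.bad_epochs = 0
        else:                         # stagnation: count it
            self.bad_epochs += 1
            if self.bad_epochs > self.patience:
                self.lr *= self.factor
                self.bad_epochs = 0
        return self.lr

sched = PlateauLR(lr=0.1)
for _ in range(12):                   # 12 stagnant epochs trigger one reduction
    lr = sched.step(1.0)
```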