The PyTorch implementation of Att-DARTS: Differentiable Neural Architecture Search for Attention.
The code is based on https://github.com/dragen1860/DARTS-PyTorch.
## Requirements

- Python == 3.7
- PyTorch == 1.0.1
- torchvision == 0.2.2
- pillow == 6.2.1
- numpy
- graphviz
- requests
- tqdm
We recommend downloading PyTorch from here.
## Datasets

- CIFAR-10/100: automatically downloaded by torchvision to the data folder (see the sketch below).
- ImageNet (ILSVRC2012 version): manually downloaded following the instructions here.
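For reference, the CIFAR download is torchvision's standard mechanism; here is a minimal sketch (the root folder name is an assumption, so check the training scripts' data flag for the actual path):

```python
# Minimal sketch: fetch CIFAR-10/100 the way torchvision normally does.
# The root folder name is an assumption; the training scripts define the actual path.
from torchvision import datasets

datasets.CIFAR10(root='data', train=True, download=True)
datasets.CIFAR100(root='data', train=True, download=True)
```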
## Results

Test error (%) on CIFAR:

| | CIFAR-10 | CIFAR-100 | Params (M) |
|---|---|---|---|
| DARTS | 2.76 ± 0.09 | 16.69 ± 0.28 | 3.3 |
| Att-DARTS | 2.54 ± 0.10 | 16.54 ± 0.40 | 3.2 |
Test error (%) on ImageNet:

| | Top-1 | Top-5 | Params (M) |
|---|---|---|---|
| DARTS | 26.7 | 8.7 | 4.7 |
| Att-DARTS | 26.0 | 8.5 | 4.6 |
## Architecture Search

Our scripts occupy all available GPUs; set the environment variable CUDA_VISIBLE_DEVICES to restrict which ones are used (e.g., CUDA_VISIBLE_DEVICES=0).

To carry out architecture search using the 2nd-order approximation, run:

```
python train_search.py --unrolled
```

The found cell will be saved in genotype.json.
The resulting Att_DARTS cell is written in genotypes.py.

Inserting the attention at other locations is supported through the --location flag; the available locations are specified by AttLocation in model_search.py.
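Purely as an illustration of the pattern (the member names below are hypothetical; the real options are whatever AttLocation defines in model_search.py):

```python
# Hypothetical sketch of an insertion-point enum; the actual members are
# defined by AttLocation in model_search.py.
from enum import Enum, auto

class AttLocation(Enum):
    AFTER_EVERY_CELL = auto()      # illustrative name only
    AFTER_REDUCTION_CELL = auto()  # illustrative name only
```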
## Architecture Evaluation

To evaluate our best cells by training from scratch, run:

```
python train_CIFAR10.py --auxiliary --cutout --arch Att_DARTS   # CIFAR-10
python train_CIFAR100.py --auxiliary --cutout --arch Att_DARTS  # CIFAR-100
python train_ImageNet.py --auxiliary --arch Att_DARTS           # ImageNet
```

Customized architectures are supported through the --arch flag once they are specified in genotypes.py (a hedged sketch of the expected format follows).
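The sketch below shows what a DARTS-style entry in genotypes.py looks like; the field names follow the upstream DARTS code, while the operations are illustrative placeholders, not a searched cell, and Att-DARTS may extend the tuple with attention entries (check genotypes.py for the exact fields):

```python
# Sketch of a DARTS-style genotype entry (field names from upstream DARTS;
# the operations below are illustrative placeholders, not a searched cell).
from collections import namedtuple

Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

MY_ARCH = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
            ('skip_connect', 0), ('sep_conv_3x3', 2)],
    normal_concat=[2, 3],
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('skip_connect', 2), ('max_pool_3x3', 1)],
    reduce_concat=[2, 3],
)
# Then, hypothetically: python train_CIFAR10.py --auxiliary --cutout --arch MY_ARCH
```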
You can also designate a search result saved in .json through the --arch_path flag:

```
python train_CIFAR10.py --auxiliary --cutout --arch_path ${PATH}   # CIFAR-10
python train_CIFAR100.py --auxiliary --cutout --arch_path ${PATH}  # CIFAR-100
python train_ImageNet.py --auxiliary --arch_path ${PATH}           # ImageNet
```

where ${PATH} should be replaced by the path to the .json file.
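For orientation, a hedged way to inspect a saved search result; it assumes genotype.json stores the genotype as plain JSON (check train_search.py for the exact schema):

```python
# Peek at a searched cell saved by train_search.py.
# Assumes plain-JSON serialization; the exact schema is defined by train_search.py.
import json

with open('genotype.json') as f:
    genotype = json.load(f)
print(genotype)
```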
The trained model is saved in trained.pt. After training, the test script runs automatically; you can also test the saved trained.pt at any time, as described in Test below.
## Test

To test a pretrained model saved in .pt, run:

```
python test_CIFAR10.py --auxiliary --model_path ${PATH} --arch Att_DARTS   # CIFAR-10
python test_CIFAR100.py --auxiliary --model_path ${PATH} --arch Att_DARTS  # CIFAR-100
python test_imagenet.py --auxiliary --model_path ${PATH} --arch Att_DARTS  # ImageNet
```

where ${PATH} should be replaced by the path to the .pt file.
You can designate our pretrained models (cifar10_att.pt, cifar100_att.pt, imagenet_att.pt) or the trained.pt saved in Architecture Evaluation. As in training, customized architectures specified in genotypes.py are supported through the --arch flag, and architectures saved in .json through the --arch_path flag.
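If you want to load trained.pt outside the test scripts, the sketch below shows a standard PyTorch loading pattern; whether the file stores a full model or a state_dict is an assumption here, so check how the training scripts save it:

```python
# Standard PyTorch loading pattern (a sketch; whether trained.pt holds a full
# model or a state_dict depends on how the training script saves it).
import torch

checkpoint = torch.load('trained.pt', map_location='cpu')
# If it is a state_dict, build the network first, then:
# model.load_state_dict(checkpoint)
# model.eval()
```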
## Visualization

You can visualize the cells in genotypes.py. For example, to visualize Att_DARTS, run:

```
python visualize.py Att_DARTS
```

You can also visualize a cell saved in .json:

```
python visualize.py genotype.json
```
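visualize.py relies on the graphviz package listed in the requirements; the sketch below shows the general pattern of drawing one edge of a cell (node and edge names are illustrative, not the repo's actual plotting code):

```python
# Illustrative graphviz pattern for drawing a cell edge; the repo's real
# visualize.py iterates over the genotype's (operation, input-node) pairs.
from graphviz import Digraph

g = Digraph(format='pdf')
g.node('c_{k-2}', fillcolor='darkseagreen2', style='filled')
g.node('0', fillcolor='lightblue', style='filled')
g.edge('c_{k-2}', '0', label='sep_conv_3x3')
g.render('cell', view=False)  # writes cell.pdf
```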
## Attentions

This repository includes the following attentions (a minimal Squeeze-and-Excitation sketch follows the list):

- Squeeze-and-Excitation (paper / code (unofficial))
- Gather-Excite (paper / code (unofficial))
- BAM (paper / code)
- CBAM (paper / code)
- A2-Nets (paper / code (unofficial))
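As a quick reference for the first entry, here is a minimal Squeeze-and-Excitation block in PyTorch; it follows the original paper's formulation and is not necessarily identical to the module used in this repository:

```python
# Minimal Squeeze-and-Excitation block (per the SE paper; for orientation only,
# not necessarily the exact module this repository searches over).
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # squeeze: global average pooling
        self.fc = nn.Sequential(             # excite: bottleneck MLP + sigmoid gate
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w  # channel-wise reweighting

x = torch.randn(2, 64, 8, 8)
print(SEBlock(64)(x).shape)  # torch.Size([2, 64, 8, 8])
```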
## Citation

```
@inproceedings{att-darts2020IJCNN,
  author    = {Nakai, Kohei and Matsubara, Takashi and Uehara, Kuniaki},
  booktitle = {The International Joint Conference on Neural Networks (IJCNN)},
  title     = {{Att-DARTS: Differentiable Neural Architecture Search for Attention}},
  year      = {2020}
}
```