This is the code for the ICML 2022 paper "Predicting Out-of-Distribution Error with the Projection Norm"
by Yaodong Yu*, Zitong Yang*, Alexander Wei, Yi Ma, Jacob Steinhardt from UC Berkeley (*equal contribution).
Requirements:
- Python
- PyTorch (1.10.0)
- CUDA
- numpy
We use CIFAR-10 (in-distribution dataset) and CIFAR-10-C (out-of-distribution datasets) to demonstrate how to compute ProjNorm. First, download CIFAR-10-C and extract it into ./data/cifar/:
mkdir -p ./data/cifar
curl -O https://zenodo.org/record/2535967/files/CIFAR-10-C.tar
tar -xvf CIFAR-10-C.tar -C data/cifar/
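The extracted archive should follow the standard CIFAR-10-C layout (a property of the public dataset, not of this repo): one .npy file per corruption, e.g. data/cifar/CIFAR-10-C/snow.npy, holding all five severities stacked as consecutive blocks of 10,000 images, plus a shared labels.npy. Below is a minimal sketch of selecting one severity, assuming that layout. The helper name load_cifar10c is hypothetical, and main.py may load the data differently.

```python
import os

import numpy as np

# CIFAR-10-C stores 10,000 test images per severity level, severities 1..5 in order.
N_PER_SEVERITY = 10000


def load_cifar10c(root, corruption, severity, n_per_severity=N_PER_SEVERITY):
    """Return (images, labels) for one corruption at one severity (1-5).

    Hypothetical helper; assumes root contains <corruption>.npy and labels.npy.
    """
    if not 1 <= severity <= 5:
        raise ValueError("severity must be in 1..5")
    # Memory-map the large image file so only the requested block is read.
    images = np.load(os.path.join(root, corruption + ".npy"), mmap_mode="r")
    labels = np.load(os.path.join(root, "labels.npy"))
    start = (severity - 1) * n_per_severity
    stop = severity * n_per_severity
    # np.array copies the selected block out of the memory map.
    return np.array(images[start:stop]), labels[start:stop]
```

For example, load_cifar10c("data/cifar/CIFAR-10-C", "snow", 5) would return the 10,000 images used by the snow/severity-5 run below.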
Step 1: train the base model and the reference model on CIFAR-10:
python init_ref_model.py --arch resnet18 --train_epoch 20 --pseudo_iters 500 --lr 0.001 --batch_size 128 --seed 1
Arguments:
- arch: network architecture
- train_epoch: number of training epochs for training the base model
- pseudo_iters: number of iterations for training the reference model
- lr: learning rate
- batch_size: mini-batch size
- seed: random seed
The base model (base_model) and the reference model (reference_model) are saved to './checkpoints/{}'.format(arch), e.g. ./checkpoints/resnet18 for the command above.
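For reference, the Step 1 flags can be mirrored with a hypothetical argparse parser. The function names, the argument types, and the defaults (copied from the example command, not from init_ref_model.py) are assumptions; only the checkpoint path follows the './checkpoints/{}'.format(arch) pattern stated above.

```python
import argparse


def build_step1_parser():
    # Hypothetical parser mirroring the documented init_ref_model.py flags.
    # Defaults come from the example command; types are assumptions.
    parser = argparse.ArgumentParser(description="Train base and reference models")
    parser.add_argument("--arch", type=str, default="resnet18", help="network architecture")
    parser.add_argument("--train_epoch", type=int, default=20,
                        help="number of training epochs for the base model")
    parser.add_argument("--pseudo_iters", type=int, default=500,
                        help="number of iterations for training the reference model")
    parser.add_argument("--lr", type=float, default=0.001, help="learning rate")
    parser.add_argument("--batch_size", type=int, default=128, help="mini-batch size")
    parser.add_argument("--seed", type=int, default=1, help="random seed")
    return parser


def checkpoint_dir(arch):
    # Same pattern as the README: './checkpoints/{}'.format(arch)
    return "./checkpoints/{}".format(arch)
```

For example, checkpoint_dir(build_step1_parser().parse_args([]).arch) gives "./checkpoints/resnet18".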
Step 2: compute ProjNorm on a CIFAR-10-C test set (here, the snow corruption at severity 5):
python main.py --arch resnet18 --corruption snow --severity 5 --pseudo_iters 500 --lr 0.001 --batch_size 128 --seed 1
Arguments:
- arch: network architecture (use the same architecture as in Step 1)
- corruption: corruption type
- severity: corruption severity
- pseudo_iters: number of iterations for training the reference model
- lr: learning rate
- batch_size: mini-batch size
- seed: random seed (use the same random seed as in Step 1)
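To evaluate more than one distribution shift, Step 2 can be repeated for each corruption and severity. The sketch below only builds the command lines. It assumes main.py accepts the 15 standard CIFAR-10-C corruption names and reuses the hyperparameters from the example command; sweep_commands is a hypothetical helper.

```python
# The 15 standard CIFAR-10-C corruptions. Whether main.py accepts every
# name is an assumption; "snow" is the one shown in the README.
CORRUPTIONS = [
    "gaussian_noise", "shot_noise", "impulse_noise",
    "defocus_blur", "glass_blur", "motion_blur", "zoom_blur",
    "snow", "frost", "fog", "brightness",
    "contrast", "elastic_transform", "pixelate", "jpeg_compression",
]


def sweep_commands(arch="resnet18", seed=1, severities=range(1, 6)):
    """Build one main.py argument list per (corruption, severity) pair.

    arch and seed are fixed across the sweep so they match Step 1.
    """
    cmds = []
    for corruption in CORRUPTIONS:
        for severity in severities:
            cmds.append([
                "python", "main.py",
                "--arch", arch,
                "--corruption", corruption,
                "--severity", str(severity),
                "--pseudo_iters", "500",
                "--lr", "0.001",
                "--batch_size", "128",
                "--seed", str(seed),
            ])
    return cmds
```

Each list can be passed to subprocess.run(cmd, check=True), or printed with shlex.join(cmd) to produce a shell script.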
main.py reports two pairs:
(in-distribution test error, in-distribution ProjNorm value)
(out-of-distribution test error, out-of-distribution ProjNorm value)
For more experimental and technical details, please see our paper. If you find this work useful, please consider citing:
@InProceedings{pmlr-v162-yu22i,
title = {Predicting Out-of-Distribution Error with the Projection Norm},
author = {Yu, Yaodong and Yang, Zitong and Wei, Alexander and Ma, Yi and Steinhardt, Jacob},
booktitle = {Proceedings of the 39th International Conference on Machine Learning},
pages = {25721--25746},
year = {2022},
editor = {Chaudhuri, Kamalika and Jegelka, Stefanie and Song, Le and Szepesvari, Csaba and Niu, Gang and Sabato, Sivan},
volume = {162},
series = {Proceedings of Machine Learning Research},
month = {17--23 Jul},
publisher = {PMLR},
pdf = {https://proceedings.mlr.press/v162/yu22i/yu22i.pdf}
}