Skip to content

Commit 677b083

Browse files
committed
For issue 27
1 parent a269023 commit 677b083

File tree

1 file changed

+129
-0
lines changed

1 file changed

+129
-0
lines changed

notebooks/issue-27.ipynb

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"name": "stdout",
10+
"output_type": "stream",
11+
"text": [
12+
"[2021-10-21 07:08:52] Try to use the default NATS-Bench (topology) path from fast_mode=False and path=/Users/xuanyidong/.torch/NATS-tss-v1_0-3ffb9.pickle.pbz2.\n"
13+
]
14+
}
15+
],
16+
"source": [
17+
"from nats_bench import create\n",
18+
"from nats_bench.api_utils import time_string\n",
19+
"import numpy as np\n",
20+
"\n",
21+
"# Create the API for size search space\n",
22+
"api_tss = create(None, \"tss\", fast_mode=False, verbose=False)"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 2,
28+
"metadata": {},
29+
"outputs": [
30+
{
31+
"name": "stdout",
32+
"output_type": "stream",
33+
"text": [
34+
"--------------------------------------------------ImageNet16-120--------------------------------------------------\n",
35+
"Best (10676) architecture on validation: |nor_conv_3x3~0|+|nor_conv_1x1~0|nor_conv_1x1~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
36+
"Best (857) architecture on test: |nor_conv_1x1~0|+|nor_conv_1x1~0|nor_conv_3x3~1|+|skip_connect~0|nor_conv_3x3~1|nor_conv_3x3~2|\n",
37+
"using validation ::: validation = 46.73, test = 46.20\n",
38+
"\n",
39+
"using test ::: validation = 46.38, test = 47.31\n",
40+
"\n"
41+
]
42+
}
43+
],
44+
"source": [
45+
"def get_valid_test_acc(api, arch, dataset):\n",
46+
" is_size_space = api.search_space_name == \"size\"\n",
47+
" if dataset == \"cifar10\":\n",
48+
" xinfo = api.get_more_info(\n",
49+
" arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n",
50+
" )\n",
51+
" test_acc = xinfo[\"test-accuracy\"]\n",
52+
" xinfo = api.get_more_info(\n",
53+
" arch,\n",
54+
" dataset=\"cifar10-valid\",\n",
55+
" hp=90 if is_size_space else 200,\n",
56+
" is_random=False,\n",
57+
" )\n",
58+
" valid_acc = xinfo[\"valid-accuracy\"]\n",
59+
" else:\n",
60+
" xinfo = api.get_more_info(\n",
61+
" arch, dataset=dataset, hp=90 if is_size_space else 200, is_random=False\n",
62+
" )\n",
63+
" valid_acc = xinfo[\"valid-accuracy\"]\n",
64+
" test_acc = xinfo[\"test-accuracy\"]\n",
65+
" return (\n",
66+
" valid_acc,\n",
67+
" test_acc,\n",
68+
" \"validation = {:.2f}, test = {:.2f}\\n\".format(valid_acc, test_acc),\n",
69+
" )\n",
70+
"\n",
71+
"def find_best_valid(api, dataset):\n",
72+
" all_valid_accs, all_test_accs = [], []\n",
73+
" for index, arch in enumerate(api):\n",
74+
" valid_acc, test_acc, perf_str = get_valid_test_acc(api, index, dataset)\n",
75+
" all_valid_accs.append((index, valid_acc))\n",
76+
" all_test_accs.append((index, test_acc))\n",
77+
" best_valid_index = sorted(all_valid_accs, key=lambda x: -x[1])[0][0]\n",
78+
" best_test_index = sorted(all_test_accs, key=lambda x: -x[1])[0][0]\n",
79+
"\n",
80+
" print(\"-\" * 50 + \"{:10s}\".format(dataset) + \"-\" * 50)\n",
81+
" print(\n",
82+
" \"Best ({:}) architecture on validation: {:}\".format(\n",
83+
" best_valid_index, api[best_valid_index]\n",
84+
" )\n",
85+
" )\n",
86+
" print(\n",
87+
" \"Best ({:}) architecture on test: {:}\".format(\n",
88+
" best_test_index, api[best_test_index]\n",
89+
" )\n",
90+
" )\n",
91+
" _, _, perf_str = get_valid_test_acc(api, best_valid_index, dataset)\n",
92+
" print(\"using validation ::: {:}\".format(perf_str))\n",
93+
" _, _, perf_str = get_valid_test_acc(api, best_test_index, dataset)\n",
94+
" print(\"using test ::: {:}\".format(perf_str))\n",
95+
"\n",
96+
"dataset = \"ImageNet16-120\"\n",
97+
"find_best_valid(api_tss, dataset)"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"metadata": {},
104+
"outputs": [],
105+
"source": []
106+
}
107+
],
108+
"metadata": {
109+
"kernelspec": {
110+
"display_name": "Python 3",
111+
"language": "python",
112+
"name": "python3"
113+
},
114+
"language_info": {
115+
"codemirror_mode": {
116+
"name": "ipython",
117+
"version": 3
118+
},
119+
"file_extension": ".py",
120+
"mimetype": "text/x-python",
121+
"name": "python",
122+
"nbconvert_exporter": "python",
123+
"pygments_lexer": "ipython3",
124+
"version": "3.8.8"
125+
}
126+
},
127+
"nbformat": 4,
128+
"nbformat_minor": 4
129+
}

0 commit comments

Comments
 (0)