forked from tensorflow/models
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain-cifar10-adaptive.sh
executable file
·104 lines (81 loc) · 1.84 KB
/
train-cifar10-adaptive.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/bin/sh
# Launches official/resnet/cifar10_main.py through kungfu-run and prints how
# long each run (and the whole script) took.
#
# Settings are read from ./config.sh, which is sourced from the directory
# that contains this script.
#
# NOTE(review): the functions below use `local`, which POSIX sh does not
# specify; confirm that the target /bin/sh (e.g. dash, busybox ash, bash)
# supports it.

# Abort as soon as any command exits with a non-zero status.
set -e
# Print the current time as whole seconds since the Unix epoch.
now() {
  date '+%s'
}
#######################################
# Run a command and report how long it took.
# Arguments:
#   $@ - the command to run followed by its arguments; each argument is
#        passed through unchanged, including ones containing whitespace.
# Outputs:
#   Whatever the command prints, followed by "<command line> took <N>s"
#   on stdout, where N is the elapsed time in whole seconds.
#######################################
measure() {
  # Declare first and assign afterwards so that a failure of `now` is not
  # hidden behind the always-successful exit status of `local`.
  local begin end duration
  begin=$(now)
  # Quoted so the command receives exactly the words the caller supplied.
  "$@"
  end=$(now)
  duration=$((end - begin))
  # "$*" joins the command words into one string for the report line.
  echo "$* took ${duration}s"
}
# Work from the directory that contains this script so that ./config.sh and
# the relative paths used later resolve regardless of the caller's cwd.
# Quoted so a script path containing spaces still works.
cd "$(dirname "$0")" || exit 1

# Load settings from the config file that sits next to this script.
# shellcheck source=/dev/null
. ./config.sh

# The script directory is used as the Python import root.
MODEL_PATH=$PWD

export PYTHONWARNINGS='ignore'
# Quoted: some sh implementations word-split unquoted `export VAR=$value`.
export PYTHONPATH="$MODEL_PATH"
export TF_CPP_MIN_LOG_LEVEL=3

# Dataset location and the parent directory for per-job model directories.
data_dir=$HOME/var/data/cifar
model_dir_prefix=$HOME/tmp/cifar10
# Default model directory; train_cifar10 overwrites this per job.
model_dir=$model_dir_prefix

# Values handed to kungfu-run through -H and -port-range.
# NOTE(review): the port range spans four ports, matching cap=4; presumably
# they must be kept in sync. Confirm against the kungfu-run documentation.
cap=4
H=127.0.0.1:$cap
port_range=40001-40004
#######################################
# Print the flags passed to kungfu-run, one flag (with its value) per line.
# Globals:
#   job_id, port_range, H (read)
# Outputs:
#   The flag list on stdout. The final line is a bare "-np", so the caller
#   has to supply the value for it as the next word.
#######################################
kungfu_run_flags() {
  printf '%s\n' \
    '-q' \
    "-logdir logs/${job_id}" \
    '-logfile kungfu-run.log' \
    '-port 40000' \
    "-port-range ${port_range}" \
    "-H ${H}" \
    '-np'
}
#######################################
# Invoke kungfu-run with the flags from kungfu_run_flags, followed by the
# caller's arguments.
# Arguments:
#   $1   - value for the trailing "-np" flag printed by kungfu_run_flags
#   $2.. - the command kungfu-run should launch, with its arguments
#######################################
kungfu_run() {
  # The flag list is deliberately left unquoted: it is emitted as text and
  # must be split into separate words. The caller's arguments are quoted so
  # each one reaches kungfu-run exactly as given.
  # shellcheck disable=SC2046
  kungfu-run $(kungfu_run_flags) "$@"
}
# export CUDA_VISIBLE_DEVICES=3
#######################################
# Join all arguments into a single comma-separated string.
# Arguments:
#   $@ - the items to join
# Outputs:
#   The joined string followed by a newline, on stdout.
#######################################
join() {
  # IFS is function-local, so the caller's field splitting is unaffected.
  local IFS
  IFS=','
  # "$*" glues the arguments together using the first character of IFS.
  set -- "$*"
  echo "$1"
}
#######################################
# Print the names of the enabled training hooks, one per line.
# Outputs:
#   Hook names on stdout, in the order they are listed here.
#######################################
hooks() {
  printf '%s\n' \
    kungfu_log_step_hook \
    kungfu_load_init_model_hook \
    kungfu_policy

  # Hooks that are currently switched off; add a name to the list above to
  # enable it:
  #   kungfu_save_model_hook
  #   kungfu_consistency_check_hook
  #   kungfu_inspect_graph_hook
  #   kungfu_change_batch_size_hook
}
#######################################
# Print the command-line flags for the training program, one flag (with
# its value) per line.
# Globals:
#   model_dir, data_dir (read)
# Outputs:
#   The flag list on stdout; the -hooks value is the output of hooks
#   joined with commas.
#######################################
app_flags() {
  # The inner $(hooks) stays unquoted so every hook name becomes its own
  # argument to join.
  # shellcheck disable=SC2046
  printf '%s\n' \
    "-md ${model_dir}" \
    "-dd ${data_dir}" \
    "-hooks $(join $(hooks))" \
    '-kungfu_opt gns'
}
#######################################
# Run one training job from scratch through kungfu_run.
# Globals:
#   model_dir_prefix (read)
#   job_id, model_dir (written; read by kungfu_run_flags and app_flags)
#   START_TIMESTAMP   (written and exported: start time in epoch seconds)
# Arguments:
#   $1 - epochs, passed to the training program as -te
#   $2 - worker count, passed to kungfu_run and to the program as -ng
#   $3 - per-worker batch size, passed as -bs and used in the job id
#######################################
train_cifar10() {
  # Quoted: some sh implementations word-split unquoted `local x=$1`.
  local epochs="$1"
  local np="$2"
  local single_bs="$3"

  # Assign and export separately so a failing `date` is not masked by the
  # exit status of `export`.
  START_TIMESTAMP=$(date +%s)
  export START_TIMESTAMP

  # Deliberately global: the flag helpers read these when building the
  # command line below.
  job_id=adaptive-bs-$single_bs
  model_dir=$model_dir_prefix/$job_id

  # Start from a clean model directory. The path is quoted so directories
  # whose names contain spaces are tested and removed correctly, and
  # ${model_dir:?} refuses to run rm on an empty path.
  if [ -d "$model_dir" ]; then
    rm -fr -- "${model_dir:?}"
  fi

  # $(app_flags) is intentionally unquoted: its output is a flag list that
  # must be split into separate words.
  # shellcheck disable=SC2046
  kungfu_run "$np" \
    python3 \
    official/resnet/cifar10_main.py \
    $(app_flags) \
    -bs "$single_bs" \
    -ng "$np" \
    -te "$epochs"
}
#######################################
# Run every enabled training configuration, timing each with measure.
# Globals:
#   cfg_epochs (read) - number of epochs; not defined in this file,
#                       presumably set by config.sh (confirm there)
#######################################
run_all() {
  # Adaptive batch size starting with 4x32 (4 workers, batch size 32 each).
  measure train_cifar10 $cfg_epochs 4 32

  # Other starting batch sizes, currently disabled:
  # measure train_cifar10 $cfg_epochs 4 64
  # measure train_cifar10 $cfg_epochs 4 128
  # measure train_cifar10 $cfg_epochs 4 256
  # measure train_cifar10 $cfg_epochs 4 512
}
# Entry point: run all enabled experiments and print the total elapsed time.
measure run_all