-
Notifications
You must be signed in to change notification settings - Fork 0
/
Experiment_Scaling_MLPs.py
121 lines (94 loc) · 4.38 KB
/
Experiment_Scaling_MLPs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch.nn as nn
class MLP(nn.Module):
    """Plain feed-forward network: Linear + activation per hidden width, then a Linear head.

    Args:
        input_dim: flattened input feature count; forward() flattens inputs from dim 1,
            so any (batch, ...) shape whose trailing dims multiply to input_dim works.
        hidden_dims: widths of the hidden layers, in order. May be empty, which
            yields a single Linear(input_dim, output_dim).
        output_dim: width of the final linear layer.
        activation: activation module *class* instantiated after each hidden Linear.
            Defaults to nn.SiLU, preserving the original behavior.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, activation=nn.SiLU):
        super().__init__()  # zero-arg super(): idiomatic since Python 3
        layers = []
        current_dim = input_dim
        for dim in hidden_dims:
            layers.append(nn.Linear(current_dim, dim))
            layers.append(activation())
            current_dim = dim
        # Output head has no activation: raw logits for cross-entropy.
        layers.append(nn.Linear(current_dim, output_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten everything but the batch dim and run through the stack."""
        return self.layers(x.flatten(1))
# One model per capacity tier; the dict keys label the hidden-layer widths.
small_mlp = MLP(784, [32], 10)
medium_mlp = MLP(784, [128, 64], 10)
big_mlp = MLP(784, [512, 256, 128, 64, 32], 10)
mlps = {
    '32': small_mlp,
    '128-64': medium_mlp,
    '512-256-128-64-32': big_mlp,
}
from utils.core_logic import directions, dists_to_models, dirs_and_dists, eval_ensemble, curvature_scale_analysis
from utils.plots import plot_df, save_fig_with_cfg
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pandas as pd
import os
# Inference-only experiment: disable autograd globally.
torch.set_grad_enabled(False)

CUDA = torch.cuda.is_available()
DEVICE = torch.device('cuda') if CUDA else torch.device('cpu')
# Run the curvature/scale experiment once per model size; collect one DataFrame each.
dataframes = []
for name, model in mlps.items():
    if CUDA:
        # Free cached GPU memory between models so the larger MLPs fit.
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # config dict for the experiment
    cfg = {
        'device': DEVICE,
        'precision': torch.float32,  # float32 or float64
        'center': model.to(DEVICE),  # the point in parameter space to start from
        'dataloader': DataLoader(
            MNIST('./data', download=True, transform=transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,))
            ])),
            batch_size=100,
            num_workers=min(os.cpu_count(), 10),
            pin_memory=CUDA,  # This is super important for speed!
        ),
        'criterion': F.cross_entropy,
        'grad': False,  # should we look in the direction of gradient ascent and descent?
        'subspaces': [],  # project directions onto modules with these substrings in their names. See project_to_module(). Leave empty for no projections.
        'rand_dirs': 40 if CUDA else 0,  # number of random directions to add, in addition to the gradient and radial directions
        'max_oom': 2 if CUDA else 0,  # furthest sample will be 10**max_oom from the center
        'min_oom': -4 if CUDA else -1,  # closest sample will be 10**min_oom from the center
    }
    torch.set_default_dtype(cfg['precision'])  # precision global default
    # SUPER IMPORTANT
    # float64: - Finite differences noise on the oom of 1e-15.
    #          - Transition from 'Quadratic' to 'Noise' behaviour at parameter distances on the order of 1e-6
    #          - Example takes 350 seconds
    #          - ca. 9 GB GPU RAM
    # float32: - Finite differences noise on the oom of 1e-8.
    #          - Transition from 'Quadratic' to 'Noise' behaviour at parameter distances on the order of 1e-3
    #          - Example takes 35 seconds
    #          - ca. 4.5 GB GPU RAM
    print(f'Running on {cfg["device"]}')

    dirs = directions(cfg)
    df = dirs_and_dists(dirs.keys(), cfg['min_oom'], cfg['max_oom'])
    df = dists_to_models(df, dirs, cfg['center'])

    # add the center model
    ensemble_list = [cfg['center']] + list(df['Model'])
    df.drop(columns='Model', inplace=True)
    ensemble_loss = eval_ensemble(ensemble_list, cfg['dataloader'], cfg['criterion'], cfg['device'])

    # First entry is the center model's loss (it was prepended to ensemble_list);
    # the remaining entries line up with df's rows. Explicit slicing replaces the
    # original's obscure starred-unpack into df['Loss'].
    losses = ensemble_loss.tolist()
    center_loss = losses[0]
    df['Loss'] = losses[1:]

    # Add the center loss to the DataFrame at 'Distance' = 0. Iterate each
    # direction once (.unique()) instead of once per row — the original
    # repeated the same .loc write for every row of each direction.
    for direction in df.index.get_level_values('Direction').unique():
        df.loc[(direction, 0), ['Distance', 'Loss']] = [0., center_loss]
    df.sort_index(inplace=True)

    df = curvature_scale_analysis(df)
    dataframes.append(df)
# average over all directions: one mean-per-Step frame per model.
# (The original also had a bare `means[0]` expression here — a notebook
# leftover with no effect in a script; removed.)
means = [df.groupby('Step').mean() for df in dataframes]

# Stack the per-model means; the model labels become the outer index level.
combined_df = pd.concat(means, keys=mlps.keys())
# Plotting with models as 'Directions'
combined_df.index.names = ['Direction', 'Step']

figs = plot_df(combined_df, 'MLPs of different sizes on MNIST')
for fig in figs:
    # NOTE(review): `cfg` here is the last loop iteration's config, so every
    # figure is saved with the biggest model's settings — confirm intended.
    save_fig_with_cfg(dir='automatic_figs', fig=fig, config=cfg)