Commit a45ed35

add pure pso optimization method

1 parent 6730bf8

File tree

3 files changed: +28 -13 lines changed

main.py (+1)
@@ -54,6 +54,7 @@ def main(args):
     parser.add_argument('--xlimit_min', default=-10, type=float, help='xlimit_min')
     parser.add_argument('--weight_particle_optmized_location', default=0.33, type=float, help='weight_particle_optmized_location')
     parser.add_argument('--weight_global_optmized_location', default=0.33, type=float, help='weight_global_optmized_location')
+    parser.add_argument('--use_sgd', default=True, type=bool, help='use_sgd')
 
     # trainer
     parser.add_argument('--divice', default="cuda" if torch.cuda.is_available() else "cpu", type=str, help='divice')
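
A caveat with the new flag (an observation, not part of this commit): argparse applies the type callable to the raw command-line string, and bool('False') is True, so passing --use_sgd False would still enable SGD; only the default or an empty string produces False. A minimal sketch of a common workaround, using a hypothetical str2bool converter:

import argparse

def str2bool(v: str) -> bool:
    # treat common truthy spellings as True, everything else as False
    return str(v).lower() in ("1", "true", "yes", "y")

parser = argparse.ArgumentParser()
parser.add_argument('--use_sgd', default=True, type=str2bool, help='use_sgd')
print(parser.parse_args(['--use_sgd', 'False']).use_sgd)  # prints False, as intended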

psosgd_optimizer.py (+22 -11)
@@ -54,7 +54,7 @@ def __setstate__(self, state):
             group.setdefault('nesterov', False)
 
     @torch.no_grad()
-    def step(self, local_best_param_group, global_best_param_group, is_psosgd, closure=None):
+    def step(self, local_best_param_group, global_best_param_group, use_pso, use_sgd, closure=None):
         """Performs a single optimization step.
 
         Arguments:
@@ -80,37 +80,48 @@ def step(self, local_best_param_group, global_best_param_group, is_psosgd, closu
             weight_global_optmized_location = group['weight_global_optmized_location']
 
             for p_index, p in enumerate(group['params']):
-                if is_psosgd:
+                if use_pso:
                     local_best_p = local_best_param_group[p_index]
                     global_best_p = global_best_param_group[p_index]
+
                 if p.grad is None:
                     continue
-                d_p = p.grad
-                if weight_decay != 0:
-                    d_p = d_p.add(p, alpha=weight_decay)
+
+                if use_sgd:
+                    d_p = p.grad
+                    if weight_decay != 0:
+                        d_p = d_p.add(p, alpha=weight_decay)
+                else:
+                    d_p = -(vlimit_min + (vlimit_max - vlimit_min) * torch.rand(p.shape))
+
                 if momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
                         buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                     else:
                         buf = param_state['momentum_buffer']
-                        # buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+
                         buf.mul_(momentum)
-                        if is_psosgd:
+                        if use_pso:
                             buf.sub_(local_best_p.sub(p), alpha=weight_particle_optmized_location * random.random())
                             buf.sub_(global_best_p.sub(p), alpha=weight_global_optmized_location * random.random())
-                        buf.add_(d_p, alpha=1-dampening)
 
-                    if is_psosgd:
+                        if use_sgd:
+                            buf.add_(d_p, alpha=1-dampening)
+
+                    if use_pso:
                         buf[buf > vlimit_max] = vlimit_max
                         buf[buf < vlimit_min] = vlimit_min
 
-                    if nesterov:
+                    if use_sgd and nesterov:
                         d_p = d_p.add(buf, alpha=momentum)
                     else:
                         d_p = buf
 
-                p.add_(d_p, alpha=-lr)
+                if use_sgd:
+                    p.add_(d_p, alpha=-lr)
+                else: # When SGD is not used, the learning rate parameter lr is invalid.
+                    p.add_(d_p, alpha=-1)
                 # p[p>xlimit_max] = xlimit_max
                 # p[p<xlimit_min] = xlimit_min
 
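
To summarize the new branch above (wording ours, not the author's): when use_sgd is False the gradient term d_p is replaced by a random velocity drawn from [vlimit_min, vlimit_max] (stored negated in the momentum buffer), the PSO terms pull each particle towards its own best and the global best positions, the velocity is clamped to the velocity limits, and the final update ignores lr and simply adds the velocity to the parameters. A minimal sketch of the equivalent textbook PSO step for one parameter tensor, assuming illustrative hyper-parameter values and a hypothetical function name:

import torch

def pure_pso_step(p, velocity, local_best_p, global_best_p,
                  inertia=0.5, c_local=0.33, c_global=0.33,
                  vlimit_min=-1.0, vlimit_max=1.0):
    # One PSO velocity/position update in the usual sign convention; the committed
    # optimizer keeps the negated velocity in its momentum buffer and compensates
    # with p.add_(d_p, alpha=-1), which is equivalent.
    r1, r2 = torch.rand(1).item(), torch.rand(1).item()
    velocity.mul_(inertia)                                 # inertia term
    velocity.add_(local_best_p - p, alpha=c_local * r1)    # cognitive term (particle best)
    velocity.add_(global_best_p - p, alpha=c_global * r2)  # social term (global best)
    velocity.clamp_(vlimit_min, vlimit_max)                # velocity limits
    p.add_(velocity)                                       # position update; lr is not used
    return p, velocity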

psosgd_trainer.py (+5 -2)
@@ -22,7 +22,8 @@ def __init__(self,
                  optimizer_config,
                  device = "cuda" if torch.cuda.is_available() else "cpu",
                  n_particle = 5,
-                 output_path = 'output', **kwargs):
+                 output_path = 'output',
+                 use_sgd = True, **kwargs):
 
         # Reserved model parameters
         self.model_config = model_config
@@ -40,6 +41,8 @@ def __init__(self,
 
         self.output_path = output_path
 
+        self.use_sgd = use_sgd
+
 
 class PSOSGD_Trainer:
 
@@ -88,7 +91,7 @@ def train(self, data_loader, loss_fn, epochs):
                     global_best_param_group = (batch_losses[i], [torch.clone(param).detach() for param in self.models[i].parameters()])
 
                 for i in range(self.config.n_particle):
-                    self.optimizers[i].step(local_best_param_groups[i][1], global_best_param_group[1], self.config.n_particle != 1)
+                    self.optimizers[i].step(local_best_param_groups[i][1], global_best_param_group[1], self.config.n_particle != 1, self.config.use_sgd)
 
                 losses.append(batch_losses)
 
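
Note that the trainer toggles the two behaviours independently: the use_pso argument passed to step() is derived from self.config.n_particle != 1, while use_sgd now comes from the config, so several particles with use_sgd=False give the pure PSO method this commit adds. The local and global best positions that step() consumes are tracked as (loss, parameter snapshot) pairs; the sketch below isolates that bookkeeping pattern (the helper name update_bests and its signature are illustrative, not code from this repository):

import torch

def update_bests(batch_losses, models, local_bests, global_best):
    # For each particle keep its best (loss, parameter snapshot) so far, plus one
    # global best across the swarm; step() then pulls particles towards these
    # positions. Illustrative sketch only.
    for i, (loss, model) in enumerate(zip(batch_losses, models)):
        snapshot = [torch.clone(p).detach() for p in model.parameters()]
        if local_bests[i] is None or loss < local_bests[i][0]:
            local_bests[i] = (loss, snapshot)
        if global_best is None or loss < global_best[0]:
            global_best = (loss, snapshot)
    return local_bests, global_best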
