@@ -54,7 +54,7 @@ def __setstate__(self, state):
             group.setdefault('nesterov', False)
 
     @torch.no_grad()
-    def step(self, local_best_param_group, global_best_param_group, is_psosgd, closure=None):
+    def step(self, local_best_param_group, global_best_param_group, use_pso, use_sgd, closure=None):
        """Performs a single optimization step.
 
        Arguments:
@@ -80,37 +80,48 @@ def step(local_best_param_group, global_best_param_group, is_psosgd, closu
             weight_global_optmized_location = group['weight_global_optmized_location']
 
             for p_index, p in enumerate(group['params']):
-                if is_psosgd:
+                if use_pso:
                     local_best_p = local_best_param_group[p_index]
                     global_best_p = global_best_param_group[p_index]
+
                 if p.grad is None:
                     continue
-                d_p = p.grad
-                if weight_decay != 0:
-                    d_p = d_p.add(p, alpha=weight_decay)
+
+                if use_sgd:
+                    d_p = p.grad
+                    if weight_decay != 0:
+                        d_p = d_p.add(p, alpha=weight_decay)
+                else:
+                    d_p = -(vlimit_min + (vlimit_max - vlimit_min) * torch.rand(p.shape))
+
                 if momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
                         buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                     else:
                         buf = param_state['momentum_buffer']
-                        # buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+
                         buf.mul_(momentum)
-                        if is_psosgd:
+                        if use_pso:
                             buf.sub_(local_best_p.sub(p), alpha=weight_particle_optmized_location * random.random())
                             buf.sub_(global_best_p.sub(p), alpha=weight_global_optmized_location * random.random())
-                        buf.add_(d_p, alpha=1 - dampening)
 
-                        if is_psosgd:
+                        if use_sgd:
+                            buf.add_(d_p, alpha=1 - dampening)
+
+                        if use_pso:
                             buf[buf > vlimit_max] = vlimit_max
                             buf[buf < vlimit_min] = vlimit_min
 
-                    if nesterov:
+                    if use_sgd and nesterov:
                         d_p = d_p.add(buf, alpha=momentum)
                     else:
                         d_p = buf
 
-                p.add_(d_p, alpha=-lr)
+                if use_sgd:
+                    p.add_(d_p, alpha=-lr)
+                else:  # When SGD is not used, the learning rate parameter lr is invalid.
+                    p.add_(d_p, alpha=-1)
                 # p[p>xlimit_max] = xlimit_max
                 # p[p<xlimit_min] = xlimit_min
 
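This change splits the single `is_psosgd` flag into two independent switches: `use_sgd` controls the gradient/weight-decay contribution, the dampening term, the Nesterov branch, and scaling the final update by `lr`; `use_pso` controls pulling the momentum buffer toward the local and global best parameters and clamping it to `[vlimit_min, vlimit_max]`. When `use_sgd` is false, `d_p` starts from a random velocity drawn from the velocity limits instead of the gradient, and the parameter moves by the full buffer rather than by an `lr`-scaled step. Below is a minimal usage sketch for the new signature. The optimizer class name `PSOSGD`, its constructor arguments, and the way the best parameters are tracked are assumptions for illustration; only the `step(local_best_param_group, global_best_param_group, use_pso, use_sgd)` call itself comes from this diff.

```python
import torch
import torch.nn as nn

# Hypothetical sketch: `PSOSGD` is an assumed name for the optimizer defined
# in this file; the keyword arguments mirror the keys read from the param
# group in the diff (vlimit_min/vlimit_max, weight_particle_optmized_location,
# weight_global_optmized_location).
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = PSOSGD(
    model.parameters(), lr=0.01, momentum=0.9,
    vlimit_min=-0.1, vlimit_max=0.1,
    weight_particle_optmized_location=0.5,
    weight_global_optmized_location=0.5,
)

# The caller keeps detached copies of the particle's best parameters (local
# best) and the swarm's best parameters (global best), in the same order as
# the optimizer's param group.
local_best = [p.detach().clone() for p in model.parameters()]
global_best = [p.detach().clone() for p in model.parameters()]
best_loss = float('inf')

x = torch.randn(64, 10)  # toy data for illustration
y = torch.randn(64, 1)

for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    # Hybrid update: gradient + momentum step (use_sgd=True), with the buffer
    # additionally pulled toward the local and global bests and clamped to
    # the velocity limits (use_pso=True).
    optimizer.step(local_best, global_best, use_pso=True, use_sgd=True)

    if loss.item() < best_loss:
        best_loss = loss.item()
        local_best = [p.detach().clone() for p in model.parameters()]
        global_best = [p.detach().clone() for p in model.parameters()]
```

With a single model there is only one particle, so the local and global bests coincide here; in a real swarm each particle would keep its own `local_best` while `global_best` is shared across particles. Passing `use_pso=True, use_sgd=False` would give a pure PSO step in which `lr` has no effect, matching the comment in the diff.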