Description
Hello there,
First of all, thanks a lot for publishing your code and actively answering questions on GitHub!!
I ran into problems when I instantiated a new instance of the FEDformer model class and loaded the weights from a previous run: I could not reproduce the same scores. Moreover, the scores differed substantially each time I reloaded the weights.
So I checked whether the model outputs stay the same when I reload previous model weights and feed identical inputs. I produced three model outputs, which should all be the same:
1. creating a model instance and running inference
2. creating a NEW model instance, loading the model weights from 1) and running inference
3. loading the saved model instance from 1) as a whole and running inference
Outputs 1, 2 and 3 should be identical. This holds when the model `version` is `Wavelets`. However, when `version` is `Fourier`, output 1 != output 2. As a result, I cannot use the weights of a `Fourier` model I trained. I could save and load the model instance as a whole, but I am still wondering why this happens. Or did I miss something?
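My current guess, which I have not verified against the FEDformer source: with `mode_select = 'random'`, the Fourier blocks may pick their frequency indices randomly in `__init__` and keep them as a plain attribute rather than as a parameter or registered buffer. `state_dict()` only contains parameters and buffers, so a NEW instance (case 2) would use different modes, while the whole pickled instance (case 3) would keep the original ones. The toy layer below is a hypothetical stand-in (`RandomModes` and `reload_matches` are my own names, not FEDformer code) that reproduces this pattern and shows that registering the indices as a buffer makes the reload exact.

```python
import numpy as np
import torch
import torch.nn as nn


class RandomModes(nn.Module):
    """Hypothetical toy layer: keeps `modes` random rFFT frequencies out of the
    first seq_len // 2, chosen once in __init__, and scales them by learned weights."""

    def __init__(self, seq_len: int, modes: int, as_buffer: bool):
        super().__init__()
        index = np.random.permutation(seq_len // 2)[:modes]
        index = torch.as_tensor(np.sort(index), dtype=torch.long)
        if as_buffer:
            # registered buffers are part of state_dict(), so they are saved and reloaded
            self.register_buffer("index", index)
        else:
            # a plain tensor attribute is NOT in state_dict(); each new instance keeps its own draw
            self.index = index
        self.weight = nn.Parameter(torch.randn(modes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_ft = torch.fft.rfft(x, dim=-1)
        out_ft = torch.zeros_like(x_ft)
        # keep only the selected frequencies, scaled by the learned weights
        out_ft[..., self.index] = x_ft[..., self.index] * self.weight
        return torch.fft.irfft(out_ft, n=x.size(-1), dim=-1)


def reload_matches(as_buffer: bool, seq_len: int = 96, modes: int = 32) -> bool:
    """Mimic case 1 vs case 2: a fresh instance loads the first one's state_dict."""
    np.random.seed(0)
    torch.manual_seed(0)
    x = torch.randn(3, 7, seq_len)
    model_1 = RandomModes(seq_len, modes, as_buffer).eval()
    model_2 = RandomModes(seq_len, modes, as_buffer).eval()  # draws new indices
    model_2.load_state_dict(model_1.state_dict())
    with torch.no_grad():
        return torch.equal(model_1(x), model_2(x))


print(reload_matches(as_buffer=False))  # plain attribute: outputs differ
print(reload_matches(as_buffer=True))   # registered buffer: outputs match
```

If this is the cause, a checkpoint saved before such a change would still not contain the indices, so it could not be fixed retroactively this way.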
This is the code to reproduce my findings:
```python
import torch

# Assumption: run from the FEDformer repository root, where the model class lives in models/FEDformer.py
from models.FEDformer import Model


class Configs(object):
    ab = 0
    modes = 32
    mode_select = 'random'
    # version = 'Wavelets'
    version = 'Fourier'
    moving_avg = [12, 24]
    L = 1
    base = 'legendre'
    cross_activation = 'tanh'
    seq_len = 96
    label_len = 48
    pred_len = 96
    output_attention = False
    enc_in = 7
    dec_in = 7
    d_model = 16
    embed = 'timeF'
    dropout = 0.05
    freq = 'h'
    factor = 1
    n_heads = 8
    d_ff = 16
    e_layers = 2
    d_layers = 1
    c_out = 7
    activation = 'gelu'
    wavelet = 0


# consistent input of just ones
configs = Configs()
enc = torch.ones([3, configs.seq_len, 7])
enc_mark = torch.ones([3, configs.seq_len, 4])
dec = torch.ones([3, configs.seq_len // 2 + configs.pred_len, 7])
dec_mark = torch.ones([3, configs.seq_len // 2 + configs.pred_len, 4])

# 1) creating a model instance and running inference
model = Model(configs)
model.eval()
with torch.no_grad():
    out_1 = model(enc, enc_mark, dec, dec_mark)
# saving only the model weights, as is done during training
torch.save(model.state_dict(), "./model_weights.pt")
# saving the class instance as a whole
torch.save(model, "./model_class_instance.pt")

# 2) creating a new model instance and loading the previous model weights
model = Model(configs)
model.load_state_dict(torch.load("./model_weights.pt"))
model.eval()
with torch.no_grad():
    out_2 = model(enc, enc_mark, dec, dec_mark)  # <-- inconsistent results when 'version' is "Fourier"

# 3) loading the whole model instance
# weights_only=False is required to unpickle a full module on PyTorch >= 2.6 (the argument exists since 1.13)
model = torch.load("./model_class_instance.pt", weights_only=False)
model.eval()
with torch.no_grad():
    out_3 = model(enc, enc_mark, dec, dec_mark)

# this should always print: True True
print(torch.equal(out_1, out_2), torch.equal(out_1, out_3))
```
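A workaround I could try, assuming the difference comes from random choices made while the model is constructed (for example because of `mode_select = 'random'`): reseed the global RNGs immediately before every `Model(configs)` call, so each construction repeats the same draws. `build_seeded` is a hypothetical helper of mine, and this only helps if the model that produced the saved weights was itself built with the same seed right before construction.

```python
import random

import numpy as np
import torch


def build_seeded(build, seed: int = 0):
    """Hypothetical helper: reseed Python, NumPy and torch RNGs, then call `build`,
    so any random choices made during construction are repeated exactly."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return build()


# quick self-check: two seeded constructions draw identical "random modes"
modes_a = build_seeded(lambda: np.sort(np.random.permutation(48)[:32]))
modes_b = build_seeded(lambda: np.sort(np.random.permutation(48)[:32]))
print(np.array_equal(modes_a, modes_b))  # True

# usage with the repro script (Model and configs as defined there):
# model = build_seeded(lambda: Model(configs))
```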
I appreciate any help on this matter!!