
Inconsistent results when recreating model instance #77

@b-nils


Hello there,

First of all, thanks a lot for publishing your code and actively answering questions on GitHub!!

I ran into problems when instantiating a new instance of the FEDformer Model class and loading the weights of a previous run: I could not reproduce the same scores. Moreover, the scores differed substantially each time I reloaded the weights.

Hence, I checked whether the model outputs stay the same when previous model weights are reloaded and a fixed input is fed in. To that end, I created three model outputs, which should all be identical:

  1. creating a model instance and running inference
  2. creating a NEW model instance, loading the model weights from 1) and running inference
  3. loading the model instance from 1) as a whole and running inference

Outputs 1), 2), and 3) should all be the same. This is true if the model version is Wavelets. However, if the version is Fourier, then 1) != 2). As a result, I am not able to reuse the weights of a Fourier model I trained. I could save and load the model instance as a whole instead, but I am still wondering why this happens. Or did I miss something?
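One way to narrow this down might be to first confirm that the saved tensors themselves survive the round trip. The minimal sketch below (using the same Model and configs as in the reproduction code that follows) compares the state_dicts of a fresh instance before and after loading; if every tensor matches but the forward passes still disagree, the discrepancy presumably comes from state created in __init__ that is not part of the state_dict:

# Minimal sketch: verify that the saved tensors round-trip exactly.
# If this prints True but the forward passes still disagree, some
# state must live outside the state_dict.
reference = Model(configs)
restored = Model(configs)
restored.load_state_dict(reference.state_dict())
print(all(torch.equal(v, restored.state_dict()[k])
          for k, v in reference.state_dict().items()))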

This is the code to reproduce my findings:

import torch

from models.FEDformer import Model  # adjust the import to your local repo layout

class Configs(object):
    ab = 0
    modes = 32
    mode_select = 'random'
    # version = 'Wavelets'
    version = 'Fourier'
    moving_avg = [12, 24]
    L = 1
    base = 'legendre'
    cross_activation = 'tanh'
    seq_len = 96
    label_len = 48
    pred_len = 96
    output_attention = False
    enc_in = 7
    dec_in = 7
    d_model = 16
    embed = 'timeF'
    dropout = 0.05
    freq = 'h'
    factor = 1
    n_heads = 8
    d_ff = 16
    e_layers = 2
    d_layers = 1
    c_out = 7
    activation = 'gelu'
    wavelet = 0
    
# consistent input of just ones
configs = Configs()
enc = torch.ones([3, configs.seq_len, 7])
enc_mark = torch.ones([3, configs.seq_len, 4])
dec = torch.ones([3, configs.seq_len//2+configs.pred_len, 7])
dec_mark = torch.ones([3, configs.seq_len//2+configs.pred_len, 4])

# 1) creating a model instance and running inference
model = Model(configs)
model.eval()
out_1 = model.forward(enc, enc_mark, dec, dec_mark)
# saving only the model weights like it is done in the training
torch.save(model.state_dict(), "./model_weights.pt")
# saving the class instance as a whole
torch.save(model, "./model_class_instance.pt")

# 2) creating a new model instance and loading previous model weights
model = Model(configs)
model.load_state_dict(torch.load("./model_weights.pt"))
model.eval()
out_2 = model.forward(enc, enc_mark, dec, dec_mark)  # <-- this leads to inconsistent results when 'version' is "Fourier"

# 3) loading the whole model instance
model = torch.load("./model_class_instance.pt")
model.eval()
out_3 = model.forward(enc, enc_mark, dec, dec_mark)

# this should always output: True, True
print(torch.equal(out_1, out_2), torch.equal(out_1, out_3))
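If I had to guess: since the config sets mode_select = 'random', the Fourier blocks might draw their frequency modes at construction time from a global RNG, and those indices would then not be restored by load_state_dict. I have not verified this in the source, but under that assumption, seeding the RNGs before each instantiation should make 1) and 2) agree. The sketch below tests this hypothesis (make_model is just a hypothetical helper for the experiment):

# Hypothetical check, assuming random mode selection happens in
# __init__ via the global random/NumPy/torch RNGs (not verified):
# seed all of them before constructing each instance so both models
# sample the same frequency modes.
import random
import numpy as np

def make_model(seed=0):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return Model(configs)

model_a = make_model()
model_a.eval()
out_a = model_a(enc, enc_mark, dec, dec_mark)

model_b = make_model()
model_b.load_state_dict(model_a.state_dict())
model_b.eval()
out_b = model_b(enc, enc_mark, dec, dec_mark)

print(torch.equal(out_a, out_b))  # True would support the guess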

I appreciate any help on this matter!!
