Challenges with registering the PyTorch inference model object 'output = runner.model.model.forward(input_state)' #424

@anna-inf-met

Description

This concerns the model-agnostic interpretability tool I am building (which we currently have working on Aurora and FastNet). We would like to try it with AIFS; however, the example I have through Anemoi is proving difficult to unpick in a way that lets me register a torch.nn.Module inference model object with a library called Zennit (it has to be Zennit, as we have extended that library significantly) and then do a simple forward pass on the model with some “input state”. This seems like a relatively innocuous piece of code that can be done in other ways with Anemoi; however, I am using PyTorch autograd ‘under the hood’ and registering hooks so that I can rewrite rules on backward passes, so I need to do it explicitly in the following way (pseudo-code follows):

import torch

# Attach the Zennit composite's hooks to the core model
zennit_comp.register(runner.model.model)

# Forward pass
output = runner.model.model.forward(input_state)

# Choose a target tensor for backward (example: output["fields"]["2t"])
output_tensor = output["fields"]["2t"]
output_tensor.retain_grad()
output_tensor.grad = None

# Backward pass, seeded with a ones tensor as the initial gradient
output_tensor.backward(torch.ones_like(output_tensor))

# Remove the Zennit hooks again
zennit_comp.remove()
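
(I realise the raw forward of the inner model most likely returns a tensor rather than the state dict sketched above; in that case I would expect to select the target along the variable axis instead. The output_tensor_index_to_variable property below is my assumption about the checkpoint API, from my reading of the anemoi-inference source:)

# Assumption: the checkpoint exposes a mapping from output tensor index
# to variable name, which can be inverted to find the "2t" channel.
idx_to_var = runner.checkpoint.output_tensor_index_to_variable
var_to_idx = {name: idx for idx, name in idx_to_var.items()}
output_tensor = output[..., var_to_idx["2t"]]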

It is getting the ‘input_state’ part from Anemoi that I am struggling with.
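
For reference, I build input_state the way the documented anemoi-inference examples do: a dict holding a date and a mapping of variable name to values on the grid. A minimal sketch of what I mean (the checkpoint name and my_fields data source are placeholders for my own setup):

import datetime

import numpy as np
from anemoi.inference.runners.simple import SimpleRunner

runner = SimpleRunner("aifs-checkpoint.ckpt", device="cuda")  # placeholder checkpoint

# my_fields is my own {variable name -> values on the grid} data source
fields = {name: np.asarray(values) for name, values in my_fields.items()}
input_state = {"date": datetime.datetime(2024, 1, 1), "fields": fields}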

This is as far as I have got. From running:

for state in runner.run(input_state=input_state, lead_time=12):
    print(state)
 
"q_50", "q_100", "q_150", "q_200", "q_250", "q_300", "q_400", "q_500", "q_600", "q_700", "q_850", "q_925", "q_1000",
    "t_50", "t_100", "t_150", "t_200", "t_250", "t_300", "t_400", "t_500", "t_600", "t_700", "t_850", "t_925", "t_1000",
    "u_50", "u_100", "u_150", "u_200", "u_250", "u_300", "u_400", "u_500", "u_600", "u_700", "u_850", "u_925", "u_1000",
    "v_50", "v_100", "v_150", "v_200", "v_250", "v_300", "v_400", "v_500", "v_600", "v_700", "v_850", "v_925", "v_1000",
    "w_50", "w_100", "w_150", "w_200", "w_250", "w_300", "w_400", "w_500", "w_600", "w_700", "w_850", "w_925", "w_1000",
    "z_50", "z_100", "z_150", "z_200", "z_250", "z_300", "z_400", "z_500", "z_600", "z_700", "z_850", "z_925", "z_1000",
    "10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "cp", "tp", "100u", "100v", "hcc", "lcc", "mcc", "ro", "sf", "ssrd",
    "stl1", "stl2", "strd", "swvl1", "swvl2", "tcc" 

I worked out that I perhaps need to organise the fields in this fashion.

My input_state also includes the additions below, and I have no way of knowing what order the model expects them in (the snippet after this list is my attempt to find out):

'longitude', 'latitude', 'z', 'cos_latitude', 'insolation', 'slor', 'time_of_day', 'cos_longitude', 'sin_longitude', 'sin_latitude', 'lsm', 'cos_julian_day', 'cos_local_time', 'sin_local_time', 'sin_julian_day', 'sdor', 'day_of_year'
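
If the checkpoint metadata exposes a variable-to-index mapping, that would settle the ordering question. I am assuming here that runner.checkpoint has a variable_to_input_tensor_index property, which is my reading of the anemoi-inference source rather than a documented API:

# Assumption: variable_to_input_tensor_index maps variable name -> column
# in the model's input tensor; printing it in index order gives the layout.
var_to_idx = runner.checkpoint.variable_to_input_tensor_index
for name, idx in sorted(var_to_idx.items(), key=lambda kv: kv[1]):
    print(idx, name)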

I have had a go at reshaping it all anyway, into

(1, 2, 1, grid, variables)

which I take to be (batch, multi-step input, ensemble, grid points, variables).

And

# Move my tensor onto the model's device
device = next(runner.model.parameters()).device
input_tensor = input_tensor.to(device).float()
 
# Apply the bundled pre-processors, then call the core model
input_tensor = runner.model.pre_processors(input_tensor, in_place=False)
y_hat = runner.model.model(input_tensor)

However, it complains that I am eight fields short: I supply 107 and it expects 115, so there must be some extra fields that the runner includes which I am unaware of (maybe land masks and the like). The diagnostic below is how I plan to find them.
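
To pin down the missing eight, I was going to diff the names I supply against what the checkpoint expects, reusing the assumed variable_to_input_tensor_index property from above (fields is my own dict of the 107 variables):

# Hypothetical diagnostic: which expected input variables am I not supplying?
expected = set(runner.checkpoint.variable_to_input_tensor_index)
provided = set(fields)  # the 107 variable names I currently build
print(sorted(expected - provided))  # hopefully the eight I am missing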
