Description
This concerns the model-agnostic interpretability tool I am building, which we currently have working on Aurora and FastNet. We would like to try it with AIFS. However, the example I have through Anemoi is proving difficult to unpick. I need a `torch.nn.Module` inference model object that I can register with a library called Zennit (it has to be Zennit, as we have extended this library significantly), and then do a simple forward pass on the model with some "input state". This seems like a relatively innocuous piece of code that could be done in other ways with Anemoi. However, I am using PyTorch autograd under the hood and registering hooks so that I can rewrite the rules applied on backward passes, so I need to do it explicitly in the following way (pseudo code follows):
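```python
import torch

# Attach Zennit's rule hooks to the underlying torch.nn.Module
zennit_comp.register(runner.model.model)
try:
    # Track gradients with respect to the input so relevance can be read off it
    input_state.requires_grad_(True)
    # Forward pass: call the module itself rather than .forward() so that
    # hooks on the top-level module also fire
    output = runner.model.model(input_state)
    # Choose a target tensor for the backward pass (example: output["fields"]["2t"])
    output_tensor = output["fields"]["2t"]
    # Backward pass, seeding every element of the target with 1
    output_tensor.backward(torch.ones_like(output_tensor))
    relevance = input_state.grad
finally:
    # Remove the Zennit hooks even if the forward or backward pass fails
    zennit_comp.remove()
```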
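If the model returns a plain tensor rather than a dict (the attempt further down binds its result to `y_hat`), the `2t` target would be one channel of the output's last dimension. The sketch below shows that channel selection and the gradient flowing back to the input on a toy stand-in. `ToyModel`, its input/output layouts, and the channel index are hypothetical, and plain autograd is used in place of Zennit's rules.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    """Hypothetical stand-in for runner.model.model (not the AIFS architecture).

    Assumed input:  (batch, time, ensemble, grid, n_in)
    Assumed output: (batch, ensemble, grid, n_out)
    """

    def __init__(self, n_time: int, n_in: int, n_out: int):
        super().__init__()
        self.linear = nn.Linear(n_time * n_in, n_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, t, e, g, v = x.shape
        # Move time next to the variables and fold both into one feature axis
        x = x.permute(0, 2, 3, 1, 4).reshape(b, e, g, t * v)
        return self.linear(x)


def channel_gradient(model: nn.Module, x: torch.Tensor, channel: int) -> torch.Tensor:
    """Backpropagate from one output channel (e.g. 2t) to the input."""
    x = x.detach().clone().requires_grad_(True)
    y = model(x)
    target = y[..., channel]  # select one variable from the last dimension
    target.backward(torch.ones_like(target))
    return x.grad


torch.manual_seed(0)
model = ToyModel(n_time=2, n_in=5, n_out=4)
x = torch.randn(1, 2, 1, 10, 5)
grad = channel_gradient(model, x, channel=3)
print(grad.shape)  # torch.Size([1, 2, 1, 10, 5])
```

With Zennit, the composite would be registered before the forward call and removed afterwards, exactly as in the pseudo code; the channel indexing is the only part that changes.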
It is getting the ‘input_state’ part from Anemoi that I am struggling with.
This is as far as I got. By running:

```python
for state in runner.run(input_state=input_state, lead_time=12):
    print(state)
```

I worked out that I probably need to organise the fields in this order:

```
"q_50", "q_100", "q_150", "q_200", "q_250", "q_300", "q_400", "q_500", "q_600", "q_700", "q_850", "q_925", "q_1000",
"t_50", "t_100", "t_150", "t_200", "t_250", "t_300", "t_400", "t_500", "t_600", "t_700", "t_850", "t_925", "t_1000",
"u_50", "u_100", "u_150", "u_200", "u_250", "u_300", "u_400", "u_500", "u_600", "u_700", "u_850", "u_925", "u_1000",
"v_50", "v_100", "v_150", "v_200", "v_250", "v_300", "v_400", "v_500", "v_600", "v_700", "v_850", "v_925", "v_1000",
"w_50", "w_100", "w_150", "w_200", "w_250", "w_300", "w_400", "w_500", "w_600", "w_700", "w_850", "w_925", "w_1000",
"z_50", "z_100", "z_150", "z_200", "z_250", "z_300", "z_400", "z_500", "z_600", "z_700", "z_850", "z_925", "z_1000",
"10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "cp", "tp", "100u", "100v", "hcc", "lcc", "mcc", "ro", "sf", "ssrd",
"stl1", "stl2", "strd", "swvl1", "swvl2", "tcc"
```
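The printed list follows a regular pattern: six pressure-level variables on 13 levels each, followed by 24 single-level fields. A sketch that rebuilds it in the same order and counts it; the names are copied from the printed output, and this is the order of the runner's output state, which need not match the order the model expects at its input.

```python
# Pressure-level variables and levels, in the order printed by runner.run()
LEVEL_VARS = ["q", "t", "u", "v", "w", "z"]
LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

# Single-level fields, in the order printed by runner.run()
SURFACE_VARS = [
    "10u", "10v", "2d", "2t", "msl", "skt", "sp", "tcw", "cp", "tp",
    "100u", "100v", "hcc", "lcc", "mcc", "ro", "sf", "ssrd",
    "stl1", "stl2", "strd", "swvl1", "swvl2", "tcc",
]

# Each variable runs through all of its levels before the next variable starts
output_names = [f"{var}_{lev}" for var in LEVEL_VARS for lev in LEVELS] + SURFACE_VARS

print(len(output_names))  # 102 (6 * 13 + 24)
```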
My input_state also includes these additional fields, and I have no way of knowing in which order the model expects them:

```
'longitude', 'latitude', 'z', 'cos_latitude', 'insolation', 'slor', 'time_of_day', 'cos_longitude', 'sin_longitude', 'sin_latitude', 'lsm', 'cos_julian_day', 'cos_local_time', 'sin_local_time', 'sin_julian_day', 'sdor', 'day_of_year'
```
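Whatever the correct order turns out to be, assembling one time step of input is then a matter of stacking the fields by name. A minimal sketch, assuming the fields are held in a dict of NumPy arrays of shape (grid,) and that the ordered list of input names is obtained separately (for example from the checkpoint's metadata); `stack_fields` and `input_names` are hypothetical names.

```python
import numpy as np


def stack_fields(fields: dict[str, np.ndarray], input_names: list[str]) -> np.ndarray:
    """Stack 1-D per-field arrays of shape (grid,) into (grid, variables).

    Column i holds fields[input_names[i]], so the column order follows
    input_names exactly, regardless of the dict's insertion order.
    A name absent from `fields` raises KeyError.
    """
    return np.stack([fields[name] for name in input_names], axis=-1)


# Toy example with a 4-point grid and a hypothetical three-variable order
fields = {
    "lsm": np.zeros(4),
    "2t": np.full(4, 280.0),
    "msl": np.full(4, 101325.0),
}
stacked = stack_fields(fields, ["2t", "msl", "lsm"])
print(stacked.shape)  # (4, 3)
```

Keying on names rather than relying on dict order means a wrong guess at the order shows up as a single list to fix, not as silently shuffled channels.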
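I've had a go at reshaping it all into a tensor of shape (1, 2, 1, grid, variables) regardless, and then running:

```python
# input_tensor: shape (1, 2, 1, grid, variables)
device = next(runner.model.parameters()).device
input_tensor = input_tensor.to(device).float()
# Apply the model's pre-processors without modifying the input in place
input_tensor = runner.model.pre_processors(input_tensor, in_place=False)
y_hat = runner.model.model(input_tensor)
```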
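The two per-time-step (grid, variables) arrays can be combined into the (1, 2, 1, grid, variables) layout with a check on the variable count before anything reaches the model. A NumPy sketch; `to_model_layout` and `n_expected` are hypothetical (the latter stands for the count the model reports, 115 in the error below), and reading the axes as (batch, time, ensemble, grid, variables) with the oldest time step first is an assumption.

```python
import numpy as np


def to_model_layout(steps: list[np.ndarray], n_expected: int) -> np.ndarray:
    """Combine per-time-step (grid, variables) arrays, oldest first, into
    (batch=1, time, ensemble=1, grid, variables)."""
    x = np.stack(steps, axis=0)  # (time, grid, variables)
    if x.shape[-1] != n_expected:
        raise ValueError(f"got {x.shape[-1]} variables, model expects {n_expected}")
    # Insert the batch axis in front and the ensemble axis after time
    return x[np.newaxis, :, np.newaxis, :, :]


grid, n_vars = 6, 115
previous = np.zeros((grid, n_vars), dtype=np.float32)
current = np.ones((grid, n_vars), dtype=np.float32)
x = to_model_layout([previous, current], n_expected=n_vars)
print(x.shape)  # (1, 2, 1, 6, 115)
```

`torch.from_numpy(x)` then gives a tensor that the snippet above can move to the model's device.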
However, it complains that I am eight fields short: I have 107 and the model expects 115, so there must be some extra fields that the runner includes and that I am unaware of (maybe land masks and the like).
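Once the full ordered list of input names the model expects is available (for example from the checkpoint's metadata), an order-preserving comparison shows exactly which fields are missing, which says more than the 107 vs 115 count alone. A sketch; `compare_names` and both lists are hypothetical and shortened for illustration.

```python
def compare_names(expected: list[str], provided: list[str]) -> tuple[list[str], list[str]]:
    """Return (missing, unexpected), each in the order of its source list."""
    provided_set = set(provided)
    expected_set = set(expected)
    missing = [name for name in expected if name not in provided_set]
    unexpected = [name for name in provided if name not in expected_set]
    return missing, unexpected


# Hypothetical, shortened lists for illustration
expected = ["2t", "msl", "lsm", "z", "cos_latitude", "insolation"]
provided = ["2t", "msl", "lsm", "z", "tp"]
missing, unexpected = compare_names(expected, provided)
print(missing)     # ['cos_latitude', 'insolation']
print(unexpected)  # ['tp']
```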