Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit 5ad4c19

Browse files
committed
fix: tensor node builder
1 parent 48f24ab commit 5ad4c19

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

src/anemoi/graphs/nodes/builders/from_file.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,30 @@ def get_coordinates(self) -> np.ndarray:
155155
class TensorNodes(BaseNodeBuilder):
156156
"""Tensor nodes"""
157157

158-
def __init__(self, tensor: torch.Tensor, name: str, lat_idx: tuple[int, int], lon_idx: tuple[int, int]) -> None:
158+
def __init__(
159+
self,
160+
tensor: torch.Tensor,
161+
name: str,
162+
lat_idx: tuple[int, int], # idx of sin(lat) & cos(lat)
163+
lon_idx: tuple[int, int], # idx of sin(lon) & cos(lon)
164+
channel_dim: int = -1,
165+
) -> None:
159166
self.data = tensor
160167
self.lat_idx = lat_idx
161168
self.lon_idx = lon_idx
169+
self.channel_dim = channel_dim
162170
super().__init__(name)
163171

172+
def undo_sincos(self) -> tuple[torch.Tensor, torch.Tensor]:
173+
sin_lat = self.data.select(self.channel_dim, self.lat_idx[0])
174+
cos_lat = self.data.select(self.channel_dim, self.lat_idx[1])
175+
sin_lon = self.data.select(self.channel_dim, self.lon_idx[0])
176+
cos_lon = self.data.select(self.channel_dim, self.lon_idx[1])
177+
178+
latitudes = np.arctan2(sin_lat, cos_lat)
179+
longitudes = np.arctan2(sin_lon, cos_lon)
180+
return latitudes, longitudes
181+
164182
def get_coordinates(self) -> torch.Tensor:
165183
"""Get the coordinates of the nodes.
166184
@@ -169,8 +187,7 @@ def get_coordinates(self) -> torch.Tensor:
169187
torch.Tensor of shape (num_nodes, 2)
170188
A 2D tensor with the coordinates, in radians.
171189
"""
172-
latitudes = np.arctan2(self.data[self.lat_idx[0]], self.data[self.lat_idx[1]]) # sin and cos(latitude)
173-
longitudes = np.arctan2(self.data[self.lon_idx[0]], self.data[self.lon_idx[1]]) # sin and cos(longitude)
190+
latitudes, longitudes = self.undo_sincos()
174191
return self.reshape_coords(latitudes, longitudes)
175192

176193
# def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData:

0 commit comments

Comments
 (0)