@@ -155,12 +155,30 @@ def get_coordinates(self) -> np.ndarray:
155
155
class TensorNodes (BaseNodeBuilder ):
156
156
"""Tensor nodes"""
157
157
158
- def __init__ (self , tensor : torch .Tensor , name : str , lat_idx : tuple [int , int ], lon_idx : tuple [int , int ]) -> None :
158
+ def __init__ (
159
+ self ,
160
+ tensor : torch .Tensor ,
161
+ name : str ,
162
+ lat_idx : tuple [int , int ], # idx of sin(lat) & cos(lat)
163
+ lon_idx : tuple [int , int ], # idx of sin(lon) & cos(lon)
164
+ channel_dim : int = - 1 ,
165
+ ) -> None :
159
166
self .data = tensor
160
167
self .lat_idx = lat_idx
161
168
self .lon_idx = lon_idx
169
+ self .channel_dim = channel_dim
162
170
super ().__init__ (name )
163
171
172
+ def undo_sincos (self ) -> tuple [torch .Tensor , torch .Tensor ]:
173
+ sin_lat = self .data .select (self .channel_dim , self .lat_idx [0 ])
174
+ cos_lat = self .data .select (self .channel_dim , self .lat_idx [1 ])
175
+ sin_lon = self .data .select (self .channel_dim , self .lon_idx [0 ])
176
+ cos_lon = self .data .select (self .channel_dim , self .lon_idx [1 ])
177
+
178
+ latitudes = np .arctan2 (sin_lat , cos_lat )
179
+ longitudes = np .arctan2 (sin_lon , cos_lon )
180
+ return latitudes , longitudes
181
+
164
182
def get_coordinates (self ) -> torch .Tensor :
165
183
"""Get the coordinates of the nodes.
166
184
@@ -169,8 +187,7 @@ def get_coordinates(self) -> torch.Tensor:
169
187
torch.Tensor of shape (num_nodes, 2)
170
188
A 2D tensor with the coordinates, in radians.
171
189
"""
172
- latitudes = np .arctan2 (self .data [self .lat_idx [0 ]], self .data [self .lat_idx [1 ]]) # sin and cos(latitude)
173
- longitudes = np .arctan2 (self .data [self .lon_idx [0 ]], self .data [self .lon_idx [1 ]]) # sin and cos(longitude)
190
+ latitudes , longitudes = self .undo_sincos ()
174
191
return self .reshape_coords (latitudes , longitudes )
175
192
176
193
# def register_attributes(self, graph: HeteroData, config: Optional[DotDict] = None) -> HeteroData:
0 commit comments