1
1
import numpy as np
2
+ import polars as pl
2
3
from warnings import warn
3
4
4
-
5
5
from uxarray .constants import ERROR_TOLERANCE , INT_DTYPE
6
6
7
7
8
- # validation helper functions
9
8
def _check_connectivity (grid ):
10
- """Check if all nodes are referenced by at least one element.
9
+ """Check if all nodes are referenced by at least one element."""
11
10
12
- If not, the mesh may have hanging nodes and may not a valid UGRID
13
- mesh
14
- """
11
+ # Convert face_node_connectivity to a Polars Series and get unique values
12
+ nodes_in_conn = pl .Series (grid .face_node_connectivity .values .flatten ()).unique ()
15
13
16
- # Check if all nodes are referenced by at least one element
17
- # get unique nodes in connectivity
18
- nodes_in_conn = np .unique (grid .face_node_connectivity .values .flatten ())
19
- # remove negative indices/fill values from the list
20
- nodes_in_conn = nodes_in_conn [nodes_in_conn >= 0 ]
14
+ # Filter out negative values
15
+ nodes_in_conn = nodes_in_conn .filter (nodes_in_conn >= 0 )
21
16
22
- # check if the size of unique nodes in connectivity is equal to the number of nodes
23
- if nodes_in_conn . size == grid .n_node :
17
+ # Check if the size of unique nodes in connectivity is equal to the number of nodes
18
+ if len ( nodes_in_conn ) == grid .n_node :
24
19
return True
25
20
else :
26
21
warn (
27
- "Some nodes may not be referenced by any element. {0} and {1}" .format (
28
- nodes_in_conn .size , grid .n_node
29
- ),
22
+ f"Some nodes may not be referenced by any element. { len (nodes_in_conn )} and { grid .n_node } " ,
30
23
RuntimeWarning ,
31
24
)
32
25
return False
@@ -35,18 +28,21 @@ def _check_connectivity(grid):
35
28
def _check_duplicate_nodes (grid ):
36
29
"""Check if there are duplicate nodes in the mesh."""
37
30
38
- coords = np .vstack ([grid .node_lon .values , grid .node_lat .values ])
39
- unique_nodes , indices = np .unique (coords , axis = 0 , return_index = True )
40
- duplicate_indices = np .setdiff1d (np .arange (len (coords )), indices )
31
+ # Convert grid to Polars DataFrame
32
+ df = pl .DataFrame ({"lon" : grid .node_lon .values , "lat" : grid .node_lat .values })
33
+
34
+ # Find unique nodes based on 'lon' and 'lat'
35
+ unique_df = df .unique (subset = ["lon" , "lat" ], maintain_order = True )
41
36
42
- if duplicate_indices .size > 0 :
37
+ # Find duplicate nodes using an anti-join
38
+ duplicate_df = df .join (unique_df , on = ["lon" , "lat" ], how = "anti" )
39
+
40
+ # Print duplicate nodes
41
+ if not duplicate_df .is_empty ():
43
42
warn (
44
- "Duplicate nodes found in the mesh. {0} nodes are duplicates." .format (
45
- duplicate_indices .size
46
- ),
43
+ f"Duplicate nodes found in the mesh. { duplicate_df .shape [0 ]} nodes are duplicates." ,
47
44
RuntimeWarning ,
48
45
)
49
- return False
50
46
else :
51
47
return True
52
48
0 commit comments