Skip to content

Commit ba511bc

Browse files
authored
Merge pull request #1155 from UXARRAY/rajeeja/use_polars_validation
o performance improvements for validation functions
2 parents f5f5a1b + 9af8dc0 commit ba511bc

File tree

1 file changed

+20
-24
lines changed

1 file changed

+20
-24
lines changed

Diff for: uxarray/grid/validation.py

+20-24
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,25 @@
11
import numpy as np
2+
import polars as pl
23
from warnings import warn
34

4-
55
from uxarray.constants import ERROR_TOLERANCE, INT_DTYPE
66

77

8-
# validation helper functions
98
def _check_connectivity(grid):
10-
"""Check if all nodes are referenced by at least one element.
9+
"""Check if all nodes are referenced by at least one element."""
1110

12-
If not, the mesh may have hanging nodes and may not be a valid UGRID
13-
mesh
14-
"""
11+
# Convert face_node_connectivity to a Polars Series and get unique values
12+
nodes_in_conn = pl.Series(grid.face_node_connectivity.values.flatten()).unique()
1513

16-
# Check if all nodes are referenced by at least one element
17-
# get unique nodes in connectivity
18-
nodes_in_conn = np.unique(grid.face_node_connectivity.values.flatten())
19-
# remove negative indices/fill values from the list
20-
nodes_in_conn = nodes_in_conn[nodes_in_conn >= 0]
14+
# Filter out negative values
15+
nodes_in_conn = nodes_in_conn.filter(nodes_in_conn >= 0)
2116

22-
# check if the size of unique nodes in connectivity is equal to the number of nodes
23-
if nodes_in_conn.size == grid.n_node:
17+
# Check if the size of unique nodes in connectivity is equal to the number of nodes
18+
if len(nodes_in_conn) == grid.n_node:
2419
return True
2520
else:
2621
warn(
27-
"Some nodes may not be referenced by any element. {0} and {1}".format(
28-
nodes_in_conn.size, grid.n_node
29-
),
22+
f"Some nodes may not be referenced by any element. {len(nodes_in_conn)} and {grid.n_node}",
3023
RuntimeWarning,
3124
)
3225
return False
@@ -35,18 +28,21 @@ def _check_connectivity(grid):
3528
def _check_duplicate_nodes(grid):
3629
"""Check if there are duplicate nodes in the mesh."""
3730

38-
coords = np.vstack([grid.node_lon.values, grid.node_lat.values])
39-
unique_nodes, indices = np.unique(coords, axis=0, return_index=True)
40-
duplicate_indices = np.setdiff1d(np.arange(len(coords)), indices)
31+
# Convert grid to Polars DataFrame
32+
df = pl.DataFrame({"lon": grid.node_lon.values, "lat": grid.node_lat.values})
33+
34+
# Find unique nodes based on 'lon' and 'lat'
35+
unique_df = df.unique(subset=["lon", "lat"], maintain_order=True)
4136

42-
if duplicate_indices.size > 0:
37+
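# Find duplicate nodes using an anti-join (NOTE(review): unique_df holds every distinct (lon, lat) pair of df, so each df row has a match and this anti-join looks like it is always empty - confirm)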
38+
duplicate_df = df.join(unique_df, on=["lon", "lat"], how="anti")
39+
40+
# Warn if any duplicate nodes were found
41+
if not duplicate_df.is_empty():
4342
warn(
44-
"Duplicate nodes found in the mesh. {0} nodes are duplicates.".format(
45-
duplicate_indices.size
46-
),
43+
f"Duplicate nodes found in the mesh. {duplicate_df.shape[0]} nodes are duplicates.",
4744
RuntimeWarning,
4845
)
49-
return False
5046
else:
5147
return True
5248

0 commit comments

Comments
 (0)