@@ -72,6 +72,7 @@ def get_values(self, prefix: str) -> List[Tuple[str, Any]]:
72
72
if m > mask :
73
73
continue
74
74
75
+ # make sure this ip is within the network defined by the mask
75
76
if self .ip == (ip & get_subnet_mask (m , v6 )):
76
77
ip_str = f"{ ip_itoa (self .ip , v6 )} /{ m } "
77
78
result .append ((ip_str , self .masks [m ]))
@@ -325,7 +326,7 @@ def find_all(self, prefix: str, children: bool=False) -> List[Tuple[str, Any]]:
325
326
"""
326
327
self .validate_ip_type_for_trie (prefix )
327
328
result = []
328
- ip , _ = cidr_atoi (prefix )
329
+ ip , mask = cidr_atoi (prefix )
329
330
330
331
# for each node on the way down
331
332
last_node = None
@@ -338,7 +339,13 @@ def find_all(self, prefix: str, children: bool=False) -> List[Tuple[str, Any]]:
338
339
if children and last_node .ip == ip :
339
340
# for each child node underneath the last found node
340
341
for node in self .traverse_inorder_from_node (last_node ):
341
- result += node .get_child_values (prefix )
342
+ # skip the first node, as it's already in the result
343
+ if node .ip == last_node .ip :
344
+ continue
345
+
346
+ # make sure this child is within the prefix
347
+ if (node .ip & get_subnet_mask (mask , self .v6 )) == (ip & get_subnet_mask (mask , self .v6 )):
348
+ result += node .get_child_values (prefix )
342
349
343
350
return result
344
351
0 commit comments