Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: Rule conversion for EFDT and VFDT trees #13

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 71 additions & 10 deletions tree_diff/tree_ruleset_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import river
from collections import deque

visit = {}
expected = {}


def find_to_condition(visited, nodes):
"""
This function finds the conditions associated with a given node in a tree and stores them in a dictionary."
"""
if nodes.is_root():
return None
else:
Expand All @@ -17,7 +19,6 @@ def find_to_condition(visited, nodes):
index = i
if index < 0:
raise ValueError("Incorrect tree")
# print(nodes.label)
if len(nodes.children) == 0:
visited["{}:{}".format(node.parent, nodes)] = [
nodes.parent.conditions[index],
Expand All @@ -31,6 +32,7 @@ def find_to_condition(visited, nodes):


def traverse(tree, visited):
visit = {}
for child in tree:
visit = find_to_condition(visited, child)
if len(child.children) > 0:
Expand All @@ -39,6 +41,12 @@ def traverse(tree, visited):


def link_dict_keys(d):
"""
This function takes the tree nodes and it's children to create a link between all of it's children.
It combines the conditions of the parent, child, and grandchild nodes into a single value.
Returns: A dictionary where keys are in the format "parent:child:grandchild" and
values are the sum of the conditions of the parent, child, and grandchild nodes.
"""
linked_dict = {}
result = {}
for key, value in d.items():
Expand Down Expand Up @@ -66,9 +74,16 @@ def link_dict_keys(d):
def tuple_tree_conversion(tree):
visited = {}
ruleset = []
expected = link_dict_keys(traverse(tree.children, visited))
for val in expected.values():
ruleset.append(Rule(val[-1], val[0:-1]))
if len(tree.children) == 0: # Check if it's a root node
attr_name = 'root_node_tree'
visited["root"] = [f"{attr_name} <= 0", tree.label]
antecedent = visited["root"][0]
ruleset.append(Rule(visited['root'][1],[f"{antecedent}"])) # Force root node to have a (antecedent) Rule and label
return Ruleset(ruleset)
else:
expected = link_dict_keys(traverse(tree.children, visited)) # Traverse for each children and find linkage
for val in expected.values():
ruleset.append(Rule(val[-1], val[0:-1]))
return Ruleset(ruleset)


Expand Down Expand Up @@ -108,29 +123,75 @@ def river_is_leaf(node):
return node.n_leaves == 1


def river_return_condition(node):
def river_return_condition(node,path,val_sum):
if isinstance(node, river.tree.nodes.efdtc_nodes.EFDTNumericBinaryBranch):
return Condition(f"attr_{node.feature}", Operator.LE, node.threshold)
weight_value = {}
for elements in range(len(path)):
current = []
for key, value in path[elements].stats.items():
current.append(value)
weight_value[elements] = sum(current) # Store the current path status
all_values = list(weight_value.values())
all_values.append(val_sum)
left,right = node.children
if left.total_weight in all_values: # Examine if either of children's weight matches with parent weight
operator = Operator.LE
elif right.total_weight in all_values:
operator = Operator.GT
return Condition(f"attr_{node.feature}", operator, node.threshold)

elif isinstance(node, tuple): # Multinomial
feature = node[0].feature
threshold = node[0]._r_mapping[node[1]]
return Condition(f"attr_{feature}", Operator.EQ, threshold)

elif isinstance(node, river.tree.nodes.efdtc_nodes.NumericBinaryBranch):
weight_value = {}
for elements in range(len(path)):
current = []
keys_index = []
for key, value in path[elements].stats.items():
current.append(value)
keys_index.append(key)
weight_value[elements] = sum(current) # Store the current path status
all_values = list(weight_value.values())
all_values.append(val_sum)
left,right = node.children
if left.total_weight in all_values: # Examine if either of children's weight matches with parent weight
operator = Operator.LE
elif right.total_weight in all_values:
operator = Operator.GT
else:
for elements in range(len(path)):
if left == path[elements]:
operator = Operator.LE
elif right == path[elements]:
operator = Operator.GT
return Condition(f"attr_{node.feature}", operator, node.threshold)
else:
raise ValueError(node)


def river_create_conditions(path_conds):
return [river_return_condition(c) for c in path_conds]
def river_create_conditions(path_conds,val_sum):
return [river_return_condition(c, path_conds, val_sum) for i, c in enumerate(path_conds)]


def river_create_rule(path):
a = path[-1].stats
weight = []
for key,values in a.items(): # Storing the original weight of parent node
weight.append(values)
val_sum = sum(weight)
m = (None, 0)
for k, v in a.items():
if not m or m[1] < v:
m = (k, v)
label = m[0]
return Rule(conditions=river_create_conditions(path[0:-1]), label=f"{label}")
if len(path) == 1: # Check if the node is a root node
random_attr = "rand"
return Rule(conditions=[Condition(f"attr_{random_attr}", Operator.LE, 0)], label=f"{label}")
else:
return Rule(conditions=river_create_conditions(path[0:-1],val_sum), label=f"{label}")


def river_extract_rules(tree, children, is_leaf):
Expand Down