Skip to content

Commit 9136dc9

Browse files
committed
fixed error in netlist abstraction
1 parent af446e6 commit 9136dc9

File tree

3 files changed

+125
-49
lines changed

3 files changed

+125
-49
lines changed
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
2+
# Example of using build_feature_vec for gate_feature with all available features
3+
4+
from hal_plugins import machine_learning
5+
6+
# Create the feature context with the netlist
7+
fc = machine_learning.Context(netlist)
8+
9+
10+
features = [
11+
#machine_learning.gate_feature.ConnectedGlobalIOs(),
12+
13+
machine_learning.gate_feature.DistanceGlobalIO(hal_py.PinDirection.output, directed=True, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
14+
machine_learning.gate_feature.DistanceGlobalIO(hal_py.PinDirection.output, directed=False, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
15+
machine_learning.gate_feature.DistanceGlobalIO(hal_py.PinDirection.input, directed=True, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
16+
machine_learning.gate_feature.DistanceGlobalIO(hal_py.PinDirection.input, directed=False, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
17+
18+
# machine_learning.gate_feature.SequentialDistanceGlobalIO(hal_py.PinDirection.output, directed=True, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
19+
# machine_learning.gate_feature.SequentialDistanceGlobalIO(hal_py.PinDirection.output, directed=False, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
20+
# machine_learning.gate_feature.SequentialDistanceGlobalIO(hal_py.PinDirection.input, directed=True, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
21+
# machine_learning.gate_feature.SequentialDistanceGlobalIO(hal_py.PinDirection.input, directed=False, forbidden_pin_types=[hal_py.PinType.clock, hal_py.PinType.reset, hal_py.PinType.enable]),
22+
23+
# machine_learning.gate_feature.IODegrees(),
24+
25+
# machine_learning.gate_feature.GateTypeOneHot(),
26+
27+
# machine_learning.gate_feature.NeighboringGateTypes(1, hal_py.PinDirection.output, directed=True),
28+
# machine_learning.gate_feature.NeighboringGateTypes(2, hal_py.PinDirection.output, directed=True),
29+
# machine_learning.gate_feature.NeighboringGateTypes(3, hal_py.PinDirection.output, directed=True),
30+
31+
# machine_learning.gate_feature.NeighboringGateTypes(1, hal_py.PinDirection.input, directed=True),
32+
# machine_learning.gate_feature.NeighboringGateTypes(2, hal_py.PinDirection.input, directed=True),
33+
# machine_learning.gate_feature.NeighboringGateTypes(3, hal_py.PinDirection.input, directed=True),
34+
35+
# machine_learning.gate_feature.BetweennessCentrality(directed = True, cutoff=-1),
36+
# machine_learning.gate_feature.BetweennessCentrality(directed = True, cutoff=16),
37+
# machine_learning.gate_feature.BetweennessCentrality(directed = False, cutoff=-1),
38+
# machine_learning.gate_feature.BetweennessCentrality(directed = False, cutoff=16),
39+
# machine_learning.gate_feature.SequentialBetweennessCentrality(directed = True, cutoff=-1),
40+
# machine_learning.gate_feature.SequentialBetweennessCentrality(directed = True, cutoff=16),
41+
# machine_learning.gate_feature.SequentialBetweennessCentrality(directed = False, cutoff=-1),
42+
# machine_learning.gate_feature.SequentialBetweennessCentrality(directed = False, cutoff=16),
43+
44+
# machine_learning.gate_feature.HarmonicCentrality(direction=hal_py.PinDirection.output, cutoff=-1),
45+
# machine_learning.gate_feature.HarmonicCentrality(direction=hal_py.PinDirection.output, cutoff=16),
46+
# machine_learning.gate_feature.HarmonicCentrality(direction=hal_py.PinDirection.inout, cutoff=-1),
47+
# machine_learning.gate_feature.HarmonicCentrality(direction=hal_py.PinDirection.inout, cutoff=16),
48+
# machine_learning.gate_feature.SequentialHarmonicCentrality(direction=hal_py.PinDirection.output, cutoff=-1),
49+
# machine_learning.gate_feature.SequentialHarmonicCentrality(direction=hal_py.PinDirection.output, cutoff=16),
50+
# machine_learning.gate_feature.SequentialHarmonicCentrality(direction=hal_py.PinDirection.inout, cutoff=-1),
51+
# machine_learning.gate_feature.SequentialHarmonicCentrality(direction=hal_py.PinDirection.inout, cutoff=16),
52+
]
53+
54+
gates = [netlist.get_gate_by_id(3)]
55+
56+
# Build the feature vector for the pair of gates
57+
feature_vector = machine_learning.gate_feature.build_feature_vecs(fc, features, gates)
58+
59+
print("Feature vector:", feature_vector)

plugins/machine_learning/src/features/gate_feature_single.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ namespace hal
8181
log_error("machine_learning", "{}", global_io_connections.get_error().get());
8282
return false;
8383
}
84+
85+
// TODO remove debug print
86+
// if (!global_io_connections.get().empty())
87+
// {
88+
// std::cout << "Global IO connections: " << global_io_connections.get().front()->get_name() << std::endl;
89+
// std::cout << "Endpoint: " << ep->get_pin()->get_name() << std::endl;
90+
// }
91+
8492
return !global_io_connections.get().empty();
8593
},
8694
m_direction,
@@ -324,5 +332,5 @@ namespace hal
324332
return OK(feature_vecs);
325333
}
326334
} // namespace gate_feature
327-
} // namespace machine_learning
335+
} // namespace machine_learning
328336
} // namespace hal

src/netlist/decorators/netlist_abstraction_decorator.cpp

Lines changed: 57 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,26 @@ namespace hal
4343
// std::cout << ep_out->get_pin()->get_name() << std::endl;
4444

4545
new_abstraction->m_successors.insert({ep_out, {}});
46+
new_abstraction->m_global_output_successors.insert({ep_out, {}});
47+
4648
const auto successors = nl_trav_dec.get_next_matching_endpoints(
4749
ep_out,
4850
true,
49-
[target_gates_set](const auto& ep) { return ep->is_destination_pin() && target_gates_set.find(ep->get_gate()) != target_gates_set.end(); },
51+
[target_gates_set](const auto& ep) {
52+
bool found_target_gate = ep->is_destination_pin() && target_gates_set.find(ep->get_gate()) != target_gates_set.end();
53+
if (found_target_gate)
54+
{
55+
return true;
56+
}
57+
58+
bool found_global_output = ep->is_source_pin() && ep->get_net()->is_global_output_net();
59+
if (found_global_output)
60+
{
61+
return true;
62+
}
63+
64+
return false;
65+
},
5066
false,
5167
exit_endpoint_filter,
5268
entry_endpoint_filter);
@@ -59,66 +75,59 @@ namespace hal
5975

6076
for (Endpoint* ep : successors.get())
6177
{
62-
new_abstraction->m_successors.at(ep_out).push_back(ep);
63-
}
64-
}
65-
66-
// gather all global output succesors
67-
for (Endpoint* ep_out : gate->get_fan_out_endpoints())
68-
{
69-
new_abstraction->m_global_output_successors.insert({ep_out, {}});
70-
71-
const auto destinations = nl_trav_dec.get_next_matching_endpoints(
72-
ep_out, true, [](const auto& ep) { return ep->is_source_pin() && ep->get_net()->is_global_output_net(); }, false, exit_endpoint_filter, entry_endpoint_filter);
73-
74-
if (destinations.is_error())
75-
{
76-
return ERR_APPEND(destinations.get_error(),
77-
"cannot build netlist abstraction: failed to gather global succesor endpoints for gate " + gate->get_name() + " with ID " + std::to_string(gate->get_id()));
78-
}
79-
80-
for (const auto* ep : destinations.get())
81-
{
82-
new_abstraction->m_global_output_successors.at(ep_out).push_back({ep->get_net()});
78+
if (ep->is_destination_pin())
79+
{
80+
new_abstraction->m_successors.at(ep_out).push_back(ep);
81+
}
82+
else if (ep->is_source_pin())
83+
{
84+
new_abstraction->m_global_output_successors.at(ep_out).push_back(ep->get_net());
85+
}
8386
}
8487
}
8588

8689
// gather all predecessors
8790
for (Endpoint* ep_in : gate->get_fan_in_endpoints())
8891
{
8992
new_abstraction->m_predecessors.insert({ep_in, {}});
93+
new_abstraction->m_global_input_predecessors.insert({ep_in, {}});
9094

91-
const auto predecessors = nl_trav_dec.get_next_matching_endpoints(
92-
ep_in, false, [target_gates_set](const auto& ep) { return ep->is_source_pin() && target_gates_set.find(ep->get_gate()) != target_gates_set.end(); });
93-
94-
if (predecessors.is_error())
95-
{
96-
return ERR_APPEND(predecessors.get_error(),
97-
"cannot build netlist abstraction: failed to gather predecessor endpoints for gate " + gate->get_name() + " with ID " + std::to_string(gate->get_id()));
98-
}
95+
const auto predecessors = nl_trav_dec.get_next_matching_endpoints(ep_in,
96+
false,
97+
[target_gates_set](const auto& ep) {
98+
bool found_target_gate = ep->is_source_pin() && target_gates_set.find(ep->get_gate()) != target_gates_set.end();
99+
if (found_target_gate)
100+
{
101+
return true;
102+
}
99103

100-
for (Endpoint* ep : predecessors.get())
101-
{
102-
new_abstraction->m_predecessors.at(ep_in).push_back(ep);
103-
}
104-
}
104+
bool found_global_input = ep->is_destination_pin() && ep->get_net()->is_global_input_net();
105+
if (found_global_input)
106+
{
107+
return true;
108+
}
105109

106-
// gather all global input predecessors
107-
for (Endpoint* ep_in : gate->get_fan_in_endpoints())
108-
{
109-
new_abstraction->m_global_input_predecessors.insert({ep_in, {}});
110+
return false;
111+
}
110112

111-
const auto predecessors = nl_trav_dec.get_next_matching_endpoints(ep_in, false, [](const auto& ep) { return ep->is_destination_pin() && ep->get_net()->is_global_input_net(); });
113+
);
112114

113115
if (predecessors.is_error())
114116
{
115117
return ERR_APPEND(predecessors.get_error(),
116-
"cannot build netlist abstraction: failed to gather global predecessor endpoints for gate " + gate->get_name() + " with ID " + std::to_string(gate->get_id()));
118+
"cannot build netlist abstraction: failed to gather predecessor endpoints for gate " + gate->get_name() + " with ID " + std::to_string(gate->get_id()));
117119
}
118120

119-
for (const auto* ep : predecessors.get())
121+
for (Endpoint* ep : predecessors.get())
120122
{
121-
new_abstraction->m_global_input_predecessors.at(ep_in).push_back({ep->get_net()});
123+
if (ep->is_source_pin())
124+
{
125+
new_abstraction->m_predecessors.at(ep_in).push_back(ep);
126+
}
127+
else if (ep->is_destination_pin())
128+
{
129+
new_abstraction->m_global_input_predecessors.at(ep_in).push_back(ep->get_net());
130+
}
122131
}
123132
}
124133
}
@@ -396,11 +405,11 @@ namespace hal
396405
for (const auto& exit_ep : current)
397406
{
398407
// currently only works for input and output pins
399-
if (exit_ep->get_pin()->get_direction() != PinDirection::output && exit_ep->get_pin()->get_direction() != PinDirection::input)
400-
{
401-
return ERR("failed to get shortest path distance: found endpoint at gate " + exit_ep->get_gate()->get_name() + " with ID " + std::to_string(exit_ep->get_gate()->get_id())
402-
+ " and pin " + exit_ep->get_pin()->get_name() + " with direction " + enum_to_string(exit_ep->get_pin()->get_direction()) + " that is currently unhandled");
403-
}
408+
// if (exit_ep->get_pin()->get_direction() != PinDirection::output && exit_ep->get_pin()->get_direction() != PinDirection::input)
409+
// {
410+
// return ERR("failed to get shortest path distance: found endpoint at gate " + exit_ep->get_gate()->get_name() + " with ID " + std::to_string(exit_ep->get_gate()->get_id())
411+
// + " and pin " + exit_ep->get_pin()->get_name() + " with direction " + enum_to_string(exit_ep->get_pin()->get_direction()) + " that is currently unhandled");
412+
// }
404413

405414
const auto entry_eps = (exit_ep->get_pin()->get_direction() == PinDirection::output) ? m_abstraction.get_successors(exit_ep) : m_abstraction.get_predecessors(exit_ep);
406415
if (entry_eps.is_error())

0 commit comments

Comments
 (0)