Skip to content

Commit 20174e6

Browse files
committed
Small optimizations to the netlist abstraction; made the machine learning context thread-safe
1 parent d2cc1ce commit 20174e6

File tree

8 files changed

+131
-77
lines changed

8 files changed

+131
-77
lines changed

include/hal_core/netlist/decorators/netlist_abstraction_decorator.h

+7-5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ namespace hal
4747
struct NETLIST_API NetlistAbstraction
4848
{
4949
public:
50+
NetlistAbstraction(NetlistAbstraction&& other) = default;
51+
5052
/**
5153
* @brief Creates a `NetlistAbstraction` from a set of gates.
5254
*
@@ -56,11 +58,11 @@ namespace hal
5658
* @param[in] exit_endpoint_filter - Filter condition to stop traversal on a fan-in/out endpoint.
5759
* @param[in] entry_endpoint_filter - Filter condition to stop traversal on a successor/predecessor endpoint.
5860
*/
59-
static Result<NetlistAbstraction> create(const Netlist* netlist,
60-
const std::vector<Gate*>& gates,
61-
const bool include_all_netlist_gates = false,
62-
const std::function<bool(const Endpoint*, const u32 current_depth)>& exit_endpoint_filter = nullptr,
63-
const std::function<bool(const Endpoint*, const u32 current_depth)>& entry_endpoint_filter = nullptr);
61+
static Result<std::shared_ptr<NetlistAbstraction>> create(const Netlist* netlist,
62+
const std::vector<Gate*>& gates,
63+
const bool include_all_netlist_gates = false,
64+
const std::function<bool(const Endpoint*, const u32 current_depth)>& exit_endpoint_filter = nullptr,
65+
const std::function<bool(const Endpoint*, const u32 current_depth)>& entry_endpoint_filter = nullptr);
6466

6567
/**
6668
* @brief Gets the predecessors of a gate within the abstraction.

plugins/machine_learning/include/machine_learning/types.h

+13-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
#include "hal_core/defines.h"
44
#include "hal_core/netlist/decorators/netlist_abstraction_decorator.h"
55

6+
#include <atomic>
7+
#include <memory>
8+
#include <mutex>
69
#include <optional>
710
#include <vector>
811

@@ -52,10 +55,16 @@ namespace hal
5255
const u32 num_threads;
5356

5457
private:
55-
std::optional<NetlistAbstraction> m_sequential_abstraction;
56-
std::optional<NetlistAbstraction> m_original_abstraction;
57-
std::optional<std::vector<GateTypeProperty>> m_possible_gate_type_properties;
58-
std::optional<MultiBitInformation> m_mbi;
58+
std::shared_ptr<MultiBitInformation> m_mbi{nullptr};
59+
std::shared_ptr<NetlistAbstraction> m_sequential_abstraction{nullptr};
60+
std::shared_ptr<NetlistAbstraction> m_original_abstraction{nullptr};
61+
std::shared_ptr<std::vector<GateTypeProperty>> m_possible_gate_type_properties{nullptr};
62+
63+
// Mutexes for thread-safe initialization
64+
std::mutex m_mbi_mutex;
65+
std::mutex m_sequential_abstraction_mutex;
66+
std::mutex m_original_abstraction_mutex;
67+
std::mutex m_possible_gate_type_properties_mutex;
5968
};
6069

6170
enum GraphDirection

plugins/machine_learning/src/features/gate_feature.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#include <vector>
1111

1212
#define MAX_DISTANCE 255
13-
#define PROGRESS_BAR
13+
// #define PROGRESS_BAR
1414

1515
namespace hal
1616
{

plugins/machine_learning/src/features/gate_pair_feature.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#include "machine_learning/features/gate_pair_feature.h"
88

99
#define MAX_DISTANCE 255
10-
#define PROGRESS_BAR
10+
// #define PROGRESS_BAR
1111

1212
namespace hal
1313
{

plugins/machine_learning/src/graph_neural_network.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ namespace hal
8989
{
9090
return ERR_APPEND(sequential_abstraction_res.get_error(), "cannot get sequential netlist abstraction for gate feature context: failed to build abstraction.");
9191
}
92-
const auto sequential_abstraction = sequential_abstraction_res.get();
92+
const auto& sequential_abstraction = sequential_abstraction_res.get();
9393

9494
// edge list
9595
std::vector<u32> sources;
@@ -100,7 +100,7 @@ namespace hal
100100
const u32 g_idx = gate_to_idx.at(g);
101101
if (dir == GraphDirection::directed)
102102
{
103-
const auto unique_predecessors = sequential_abstraction.get_unique_predecessors(g);
103+
const auto unique_predecessors = sequential_abstraction->get_unique_predecessors(g);
104104
if (unique_predecessors.is_error())
105105
{
106106
return ERR_APPEND(unique_predecessors.get_error(),
@@ -115,7 +115,7 @@ namespace hal
115115

116116
if (dir == GraphDirection::undirected)
117117
{
118-
const auto unique_successors = sequential_abstraction.get_unique_successors(g);
118+
const auto unique_successors = sequential_abstraction->get_unique_successors(g);
119119
if (unique_successors.is_error())
120120
{
121121
return ERR_APPEND(unique_successors.get_error(),

plugins/machine_learning/src/types.cpp

+84-46
Original file line numberDiff line numberDiff line change
@@ -272,96 +272,134 @@ namespace hal
272272

273273
const MultiBitInformation& Context::get_multi_bit_information()
274274
{
275-
if (!m_mbi.has_value())
275+
auto mbi = std::atomic_load_explicit(&m_mbi, std::memory_order_acquire);
276+
if (mbi)
276277
{
277-
const auto seq_gates = nl->get_gates([](const auto* g) { return g->get_type()->has_property(GateTypeProperty::sequential); });
278-
m_mbi = calculate_multi_bit_information(seq_gates);
278+
return *mbi;
279279
}
280+
else
281+
{
282+
std::lock_guard<std::mutex> lock(m_mbi_mutex);
283+
mbi = std::atomic_load_explicit(&m_mbi, std::memory_order_acquire);
284+
if (mbi)
285+
{
286+
return *mbi;
287+
}
280288

281-
return m_mbi.value();
289+
auto new_mbi = std::make_shared<MultiBitInformation>();
290+
const auto seq_gates = nl->get_gates([](const auto* g) { return g->get_type()->has_property(GateTypeProperty::sequential); });
291+
*new_mbi = calculate_multi_bit_information(seq_gates);
292+
293+
std::atomic_store_explicit(&m_mbi, new_mbi, std::memory_order_release);
294+
295+
return *new_mbi;
296+
}
282297
}
283298

284299
const Result<NetlistAbstraction*> Context::get_sequential_abstraction()
285300
{
286-
if (!m_sequential_abstraction.has_value())
301+
auto abstraction = std::atomic_load_explicit(&m_sequential_abstraction, std::memory_order_acquire);
302+
if (abstraction)
287303
{
304+
return OK(abstraction.get());
305+
}
306+
else
307+
{
308+
std::lock_guard<std::mutex> lock(m_sequential_abstraction_mutex);
309+
// Double-check after acquiring the lock
310+
abstraction = std::atomic_load_explicit(&m_sequential_abstraction, std::memory_order_acquire);
311+
if (abstraction)
312+
{
313+
return OK(abstraction.get());
314+
}
315+
288316
const auto seq_gates = nl->get_gates([](const auto* g) { return g->get_type()->has_property(GateTypeProperty::sequential); });
289317

290-
const std::vector<PinType> forbidden_pins = {
291-
PinType::clock, /*PinType::done, PinType::error, PinType::error_detection,*/ /*PinType::none,*/ PinType::ground, PinType::power /*, PinType::status*/};
318+
const std::vector<PinType> forbidden_pins = {PinType::clock, PinType::ground, PinType::power};
292319

293-
const auto endpoint_filter = [forbidden_pins](const auto* ep, const auto& _d) {
294-
UNUSED(_d);
320+
const auto endpoint_filter = [forbidden_pins](const auto* ep, const auto&) {
295321
return std::find(forbidden_pins.begin(), forbidden_pins.end(), ep->get_pin()->get_type()) == forbidden_pins.end();
296322
};
297323

298-
const auto sequential_abstraction_res = NetlistAbstraction::create(nl, seq_gates, true, endpoint_filter, endpoint_filter);
324+
auto sequential_abstraction_res = NetlistAbstraction::create(nl, seq_gates, true, endpoint_filter, endpoint_filter);
299325
if (sequential_abstraction_res.is_error())
300326
{
301-
return ERR_APPEND(sequential_abstraction_res.get_error(), "cannot get sequential netlist abstraction for gate feature context: failed to build abstraction.");
327+
return ERR_APPEND(sequential_abstraction_res.get_error(), "Cannot get sequential netlist abstraction: failed to build abstraction.");
302328
}
303329

304-
m_sequential_abstraction = sequential_abstraction_res.get();
330+
auto new_abstraction = sequential_abstraction_res.get();
305331

306-
// TODO remove debug print
307-
// std::cout << "Built abstraction" << std::endl;
308-
}
332+
std::atomic_store_explicit(&m_sequential_abstraction, new_abstraction, std::memory_order_release);
309333

310-
return OK(&m_sequential_abstraction.value());
334+
return OK(m_sequential_abstraction.get());
335+
}
311336
}
312337

313338
const Result<NetlistAbstraction*> Context::get_original_abstraction()
314339
{
315-
if (!m_original_abstraction.has_value())
340+
auto abstraction = std::atomic_load_explicit(&m_original_abstraction, std::memory_order_acquire);
341+
if (abstraction)
316342
{
317-
// const std::vector<PinType> forbidden_pins = {
318-
// PinType::clock, /*PinType::done, PinType::error, PinType::error_detection,*/ /*PinType::none,*/ PinType::ground, PinType::power /*, PinType::status*/};
319-
320-
// const auto endpoint_filter = [forbidden_pins](const auto* ep, const auto& _d) {
321-
// UNUSED(_d);
322-
// return std::find(forbidden_pins.begin(), forbidden_pins.end(), ep->get_pin()->get_type()) == forbidden_pins.end();
323-
// };
343+
return OK(abstraction.get());
344+
}
345+
else
346+
{
347+
std::lock_guard<std::mutex> lock(m_original_abstraction_mutex);
348+
// Double-check after acquiring the lock
349+
abstraction = std::atomic_load_explicit(&m_original_abstraction, std::memory_order_acquire);
350+
if (abstraction)
351+
{
352+
return OK(abstraction.get());
353+
}
324354

325-
const auto original_abstraction_res = NetlistAbstraction::create(nl, nl->get_gates(), true, nullptr, nullptr);
355+
auto original_abstraction_res = NetlistAbstraction::create(nl, nl->get_gates(), true, nullptr, nullptr);
326356
if (original_abstraction_res.is_error())
327357
{
328-
return ERR_APPEND(original_abstraction_res.get_error(), "cannot get original netlist abstraction for gate feature context: failed to build abstraction.");
358+
return ERR_APPEND(original_abstraction_res.get_error(), "Cannot get original netlist abstraction: failed to build abstraction.");
329359
}
330360

331-
m_original_abstraction = original_abstraction_res.get();
361+
auto new_abstraction = original_abstraction_res.get();
332362

333-
// TODO remove debug print
334-
// std::cout << "Built abstraction" << std::endl;
335-
}
363+
std::atomic_store_explicit(&m_original_abstraction, new_abstraction, std::memory_order_release);
336364

337-
return OK(&m_original_abstraction.value());
365+
return OK(m_original_abstraction.get());
366+
}
338367
}
339368

340369
const std::vector<GateTypeProperty>& Context::get_possible_gate_type_properties()
341370
{
342-
if (!m_possible_gate_type_properties.has_value())
371+
auto properties = std::atomic_load_explicit(&m_possible_gate_type_properties, std::memory_order_acquire);
372+
if (properties)
373+
{
374+
return *properties;
375+
}
376+
else
343377
{
344-
std::set<GateTypeProperty> properties;
378+
std::lock_guard<std::mutex> lock(m_possible_gate_type_properties_mutex);
379+
// Double-check after acquiring the lock
380+
properties = std::atomic_load_explicit(&m_possible_gate_type_properties, std::memory_order_acquire);
381+
if (properties)
382+
{
383+
return *properties;
384+
}
385+
386+
std::set<GateTypeProperty> property_set;
345387

346388
for (const auto& [_name, gt] : nl->get_gate_library()->get_gate_types())
347389
{
348-
const auto gt_properties = gt->get_properties();
349-
properties.insert(gt_properties.begin(), gt_properties.end());
390+
const auto& gt_properties = gt->get_properties();
391+
property_set.insert(gt_properties.begin(), gt_properties.end());
350392
}
351393

352-
// for (auto& [gtp, _name] : EnumStrings<GateTypeProperty>::data)
353-
// {
354-
// UNUSED(_name);
355-
// properties.insert(gtp);
356-
// }
394+
auto properties_vec = std::make_shared<std::vector<GateTypeProperty>>(property_set.begin(), property_set.end());
357395

358-
auto properties_vec = utils::to_vector(properties);
359-
// sort alphabetically
360-
std::sort(properties_vec.begin(), properties_vec.end(), [](const auto& a, const auto& b) { return enum_to_string(a) < enum_to_string(b); });
361-
m_possible_gate_type_properties = properties_vec;
362-
}
396+
// Sort alphabetically
397+
std::sort(properties_vec->begin(), properties_vec->end(), [](const auto& a, const auto& b) { return enum_to_string(a) < enum_to_string(b); });
363398

364-
return m_possible_gate_type_properties.value();
399+
std::atomic_store_explicit(&m_possible_gate_type_properties, properties_vec, std::memory_order_release);
400+
401+
return *properties_vec;
402+
}
365403
}
366404
} // namespace machine_learning
367405
} // namespace hal

src/netlist/decorators/netlist_abstraction_decorator.cpp

+21-16
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,25 @@
77

88
namespace hal
99
{
10-
Result<NetlistAbstraction> NetlistAbstraction::create(const Netlist* netlist,
11-
const std::vector<Gate*>& gates,
12-
const bool include_all_netlist_gates,
13-
const std::function<bool(const Endpoint*, const u32 current_depth)>& exit_endpoint_filter,
14-
const std::function<bool(const Endpoint*, const u32 current_depth)>& entry_endpoint_filter)
10+
Result<std::shared_ptr<NetlistAbstraction>> NetlistAbstraction::create(const Netlist* netlist,
11+
const std::vector<Gate*>& gates,
12+
const bool include_all_netlist_gates,
13+
const std::function<bool(const Endpoint*, const u32 current_depth)>& exit_endpoint_filter,
14+
const std::function<bool(const Endpoint*, const u32 current_depth)>& entry_endpoint_filter)
1515
{
1616
const auto nl_trav_dec = NetlistTraversalDecorator(*netlist);
1717

1818
// transform gates into set to check fast if a gate is part of abstraction
19-
const auto gates_set = utils::to_unordered_set(gates);
19+
const auto gates_set = utils::to_unordered_set(gates);
20+
const auto& included_gates = include_all_netlist_gates ? netlist->get_gates() : gates;
2021

21-
auto new_abstraction = NetlistAbstraction();
22+
auto new_abstraction = std::shared_ptr<NetlistAbstraction>(new NetlistAbstraction());
23+
const u32 approximated_endpoint_count = included_gates.size() * 8;
24+
new_abstraction->m_successors.reserve(approximated_endpoint_count);
25+
new_abstraction->m_predecessors.reserve(approximated_endpoint_count);
26+
new_abstraction->m_global_output_successors.reserve(approximated_endpoint_count);
27+
new_abstraction->m_global_input_predecessors.reserve(approximated_endpoint_count);
2228

23-
const auto& included_gates = include_all_netlist_gates ? netlist->get_gates() : gates;
2429
for (const Gate* gate : included_gates)
2530
{
2631
// TODO remove debug print
@@ -29,7 +34,7 @@ namespace hal
2934
// gather all successors
3035
for (Endpoint* ep_out : gate->get_fan_out_endpoints())
3136
{
32-
new_abstraction.m_successors.insert({ep_out, {}});
37+
new_abstraction->m_successors.insert({ep_out, {}});
3338
const auto successors = nl_trav_dec.get_next_matching_endpoints(
3439
ep_out,
3540
true,
@@ -46,14 +51,14 @@ namespace hal
4651

4752
for (Endpoint* ep : successors.get())
4853
{
49-
new_abstraction.m_successors.at(ep_out).push_back(ep);
54+
new_abstraction->m_successors.at(ep_out).push_back(ep);
5055
}
5156
}
5257

5358
// gather all global output successors
5459
for (Endpoint* ep_out : gate->get_fan_out_endpoints())
5560
{
56-
new_abstraction.m_global_output_successors.insert({ep_out, {}});
61+
new_abstraction->m_global_output_successors.insert({ep_out, {}});
5762

5863
const auto destinations = nl_trav_dec.get_next_matching_endpoints(
5964
ep_out, true, [](const auto& ep) { return ep->is_source_pin() && ep->get_net()->is_global_output_net(); }, false, exit_endpoint_filter, entry_endpoint_filter);
@@ -66,14 +71,14 @@ namespace hal
6671

6772
for (const auto* ep : destinations.get())
6873
{
69-
new_abstraction.m_global_output_successors.at(ep_out).push_back({ep->get_net()});
74+
new_abstraction->m_global_output_successors.at(ep_out).push_back({ep->get_net()});
7075
}
7176
}
7277

7378
// gather all predecessors
7479
for (Endpoint* ep_in : gate->get_fan_in_endpoints())
7580
{
76-
new_abstraction.m_predecessors.insert({ep_in, {}});
81+
new_abstraction->m_predecessors.insert({ep_in, {}});
7782

7883
const auto predecessors =
7984
nl_trav_dec.get_next_matching_endpoints(ep_in, false, [gates_set](const auto& ep) { return ep->is_source_pin() && gates_set.find(ep->get_gate()) != gates_set.end(); });
@@ -86,14 +91,14 @@ namespace hal
8691

8792
for (Endpoint* ep : predecessors.get())
8893
{
89-
new_abstraction.m_predecessors.at(ep_in).push_back(ep);
94+
new_abstraction->m_predecessors.at(ep_in).push_back(ep);
9095
}
9196
}
9297

9398
// gather all global input predecessors
9499
for (Endpoint* ep_in : gate->get_fan_in_endpoints())
95100
{
96-
new_abstraction.m_global_input_predecessors.insert({ep_in, {}});
101+
new_abstraction->m_global_input_predecessors.insert({ep_in, {}});
97102

98103
const auto predecessors = nl_trav_dec.get_next_matching_endpoints(ep_in, false, [](const auto& ep) { return ep->is_destination_pin() && ep->get_net()->is_global_input_net(); });
99104

@@ -105,7 +110,7 @@ namespace hal
105110

106111
for (const auto* ep : predecessors.get())
107112
{
108-
new_abstraction.m_global_input_predecessors.at(ep_in).push_back({ep->get_net()});
113+
new_abstraction->m_global_input_predecessors.at(ep_in).push_back({ep->get_net()});
109114
}
110115
}
111116
}

0 commit comments

Comments
 (0)