Skip to content

Commit dc7939f

Browse files
authored
flow-control: Add cache node optimization. (#8115)
1 parent 88e71f0 commit dc7939f

File tree

4 files changed

+82
-20
lines changed

4 files changed

+82
-20
lines changed

crates/cairo-lang-lowering/src/lower/flow_control/create_graph.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use cache::Cache;
12
use cairo_lang_semantic::{self as semantic, Condition, PatternId};
23
use cairo_lang_syntax::node::TypedStablePtr;
4+
use filtered_patterns::IndexAndBindings;
35
use itertools::Itertools;
46
use patterns::{CreateNodeParams, create_node_for_patterns, get_pattern};
57

@@ -9,6 +11,7 @@ use super::graph::{
911
};
1012
use crate::lower::context::LoweringContext;
1113

14+
mod cache;
1215
mod filtered_patterns;
1316
mod patterns;
1417

@@ -61,6 +64,8 @@ pub fn create_graph_expr_if<'db>(
6164
let expr_location = ctx.get_location(expr.stable_ptr().untyped());
6265
let expr_var = graph.new_var(expr.ty(), expr_location);
6366

67+
let mut cache = Cache::default();
68+
6469
let match_node_id = create_node_for_patterns(
6570
CreateNodeParams {
6671
ctx,
@@ -69,9 +74,15 @@ pub fn create_graph_expr_if<'db>(
6974
.iter()
7075
.map(|pattern| Some(get_pattern(ctx, *pattern)))
7176
.collect_vec(),
72-
build_node_callback: &|graph, pattern_indices| {
77+
build_node_callback: &mut |graph, pattern_indices| {
7378
if let Some(index_and_bindings) = pattern_indices.first() {
74-
index_and_bindings.wrap_node(graph, current_node)
79+
cache.get_or_compute(
80+
&mut |graph, index_and_bindings: IndexAndBindings| {
81+
index_and_bindings.wrap_node(graph, current_node)
82+
},
83+
graph,
84+
index_and_bindings,
85+
)
7586
} else {
7687
false_branch
7788
}
@@ -119,6 +130,8 @@ pub fn create_graph_expr_match<'db>(
119130
})
120131
.collect();
121132

133+
let mut cache = Cache::default();
134+
122135
// TODO(lior): add diagnostics if there is an unreachable arm.
123136
let match_node_id = create_node_for_patterns(
124137
CreateNodeParams {
@@ -128,11 +141,17 @@ pub fn create_graph_expr_match<'db>(
128141
.iter()
129142
.map(|(pattern, _)| Some(get_pattern(ctx, *pattern)))
130143
.collect_vec(),
131-
build_node_callback: &|graph, pattern_indices| {
144+
build_node_callback: &mut |graph, pattern_indices| {
132145
// TODO(lior): add diagnostics if pattern_indices is empty (instead of `unwrap`).
133146
let index_and_bindings = pattern_indices.first().unwrap();
134-
let index = index_and_bindings.index();
135-
index_and_bindings.wrap_node(graph, pattern_and_nodes[index].1)
147+
cache.get_or_compute(
148+
&mut |graph, index_and_bindings: IndexAndBindings| {
149+
let index = index_and_bindings.index();
150+
index_and_bindings.wrap_node(graph, pattern_and_nodes[index].1)
151+
},
152+
graph,
153+
index_and_bindings,
154+
)
136155
},
137156
location: matched_expr_location,
138157
},
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
2+
3+
use crate::lower::flow_control::graph::{FlowControlGraphBuilder, NodeId};
4+
5+
/// Implements a simple memoization mechanism to optimize the flow control graph that is created.
6+
///
7+
/// The cache is used before calling `BuildNodeCallback` to avoid creating two nodes that behave
8+
/// identically.
9+
pub struct Cache<Input> {
10+
/// A map from input to the cached result.
11+
cache: UnorderedHashMap<Input, NodeId>,
12+
}
13+
impl<Input: std::hash::Hash + Eq + Clone> Cache<Input> {
14+
/// Calls the callback if this is the first time the input is seen.
15+
/// Returns the previous result, otherwise.
16+
pub fn get_or_compute<'db>(
17+
&mut self,
18+
callback: &mut dyn FnMut(&mut FlowControlGraphBuilder<'db>, Input) -> NodeId,
19+
graph: &mut FlowControlGraphBuilder<'db>,
20+
input: Input,
21+
) -> NodeId {
22+
if let Some(node_id) = self.cache.get(&input) {
23+
return *node_id;
24+
}
25+
26+
let node_id = callback(graph, input.clone());
27+
assert!(!self.cache.contains_key(&input));
28+
self.cache.insert(input, node_id);
29+
node_id
30+
}
31+
}
32+
33+
impl<Input: std::hash::Hash + Eq + Clone> std::default::Default for Cache<Input> {
34+
fn default() -> Self {
35+
Self { cache: Default::default() }
36+
}
37+
}

crates/cairo-lang-lowering/src/lower/flow_control/create_graph/patterns.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use itertools::Itertools;
1010
use super::super::graph::{
1111
EnumMatch, FlowControlGraphBuilder, FlowControlNode, FlowControlVar, NodeId,
1212
};
13+
use super::cache::Cache;
1314
use super::filtered_patterns::{Bindings, FilteredPatterns};
1415
use crate::ids::LocationId;
1516
use crate::lower::context::LoweringContext;
@@ -39,7 +40,7 @@ use crate::lower::flow_control::graph::Deconstruct;
3940
/// Finally, the inner pattern-matching function (for `x`) will construct a [EnumMatch] node
4041
/// that leads to the two nodes returned by the callback.
4142
type BuildNodeCallback<'db, 'a> =
42-
&'a dyn Fn(&mut FlowControlGraphBuilder<'db>, FilteredPatterns) -> NodeId;
43+
&'a mut dyn FnMut(&mut FlowControlGraphBuilder<'db>, FilteredPatterns) -> NodeId;
4344

4445
/// A thin wrapper around [semantic::Pattern], where `None` represents the `_` pattern.
4546
type PatternOption<'a, 'db> = Option<&'a semantic::Pattern<'db>>;
@@ -82,11 +83,17 @@ pub fn create_node_for_patterns<'db>(
8283
})
8384
.collect_vec();
8485

85-
// Wrap `build_node_callback` to add the bindings to the patterns.
86-
let build_node_callback = |graph: &mut FlowControlGraphBuilder<'db>,
87-
pattern_indices: FilteredPatterns| {
88-
build_node_callback(graph, pattern_indices.add_bindings(bindings.clone()))
89-
};
86+
let mut cache = Cache::default();
87+
88+
// Wrap `build_node_callback` to add the bindings to the patterns and cache the result.
89+
let mut build_node_callback =
90+
|graph: &mut FlowControlGraphBuilder<'db>, pattern_indices: FilteredPatterns| {
91+
cache.get_or_compute(
92+
build_node_callback,
93+
graph,
94+
pattern_indices.add_bindings(bindings.clone()),
95+
)
96+
};
9097

9198
// If all the patterns are catch-all, we do not need to look into `input_var`.
9299
if patterns.iter().all(|pattern| pattern_is_any(pattern)) {
@@ -99,7 +106,7 @@ pub fn create_node_for_patterns<'db>(
99106
ctx,
100107
graph,
101108
patterns: &patterns,
102-
build_node_callback: &build_node_callback,
109+
build_node_callback: &mut build_node_callback,
103110
location,
104111
};
105112
match long_ty {
@@ -169,7 +176,7 @@ fn create_node_for_enum<'db>(
169176
ctx,
170177
graph,
171178
patterns: &inner_patterns,
172-
build_node_callback: &|graph, pattern_indices_inner| {
179+
build_node_callback: &mut |graph, pattern_indices_inner| {
173180
build_node_callback(graph, pattern_indices_inner.lift(&pattern_indices))
174181
},
175182
location,
@@ -262,14 +269,14 @@ fn create_node_for_tuple_inner<'db>(
262269
ctx,
263270
graph,
264271
patterns: &patterns_on_current_item,
265-
build_node_callback: &|graph, pattern_indices| {
272+
build_node_callback: &mut |graph, pattern_indices| {
266273
// Call `create_node_for_tuple_inner` recursively to handle the rest of the tuple.
267274
create_node_for_tuple_inner(
268275
CreateNodeParams {
269276
ctx,
270277
graph,
271278
patterns: &pattern_indices.indices().map(|idx| patterns[idx]).collect_vec(),
272-
build_node_callback: &|graph, pattern_indices_inner| {
279+
build_node_callback: &mut |graph, pattern_indices_inner| {
273280
build_node_callback(graph, pattern_indices_inner.lift(&pattern_indices))
274281
},
275282
location,

crates/cairo-lang-lowering/src/lower/flow_control/test_data/match

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,14 @@ enum Color {
193193
}
194194

195195
//! > graph
196-
Root: 7
196+
Root: 6
197197
0 ArmExpr { expr: ExprId(1) }
198198
1 ArmExpr { expr: ExprId(2) }
199199
2 ArmExpr { expr: ExprId(3) }
200200
3 EnumMatch { matched_var: v2, variants: (NodeId(1), v8), (NodeId(2), v9), (NodeId(2), v10)}
201-
4 EnumMatch { matched_var: v2, variants: (NodeId(1), v12), (NodeId(2), v13), (NodeId(2), v14)}
202-
5 EnumMatch { matched_var: v1, variants: (NodeId(0), v3), (NodeId(3), v7), (NodeId(4), v11)}
203-
6 Deconstruct { input: v0, outputs: [v1, v2], next: NodeId(5) }
204-
7 EvaluateExpr { expr: ExprId(0), var_id: v0, next: NodeId(6) }
201+
4 EnumMatch { matched_var: v1, variants: (NodeId(0), v3), (NodeId(3), v7), (NodeId(3), v11)}
202+
5 Deconstruct { input: v0, outputs: [v1, v2], next: NodeId(4) }
203+
6 EvaluateExpr { expr: ExprId(0), var_id: v0, next: NodeId(5) }
205204

206205
//! > semantic_diagnostics
207206

0 commit comments

Comments
 (0)