Skip to content

Commit f61aad1

Browse files
committed
flow-control: Add cache node optimization.
commit-id:edbf06d2
1 parent 136da18 commit f61aad1

File tree

4 files changed

+77
-10
lines changed

4 files changed

+77
-10
lines changed

crates/cairo-lang-lowering/src/lower/flow_control/create_graph.rs

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
use cache::Cache;
12
use cairo_lang_semantic::{self as semantic, Condition, PatternId};
23
use cairo_lang_syntax::node::TypedStablePtr;
4+
use filtered_patterns::IndexAndBindings;
35
use itertools::Itertools;
46
use patterns::{CreateNodeParams, create_node_for_patterns, get_pattern};
57

@@ -9,6 +11,7 @@ use super::graph::{
911
};
1012
use crate::lower::context::LoweringContext;
1113

14+
mod cache;
1215
mod filtered_patterns;
1316
mod patterns;
1417

@@ -61,6 +64,8 @@ pub fn create_graph_expr_if<'db>(
6164
let expr_location = ctx.get_location(expr.stable_ptr().untyped());
6265
let expr_var = graph.new_var(expr.ty(), expr_location);
6366

67+
let cache = Cache::default();
68+
6469
let match_node_id = create_node_for_patterns(
6570
CreateNodeParams {
6671
ctx,
@@ -71,7 +76,13 @@ pub fn create_graph_expr_if<'db>(
7176
.collect_vec(),
7277
build_node_callback: &|graph, pattern_indices| {
7378
if let Some(index_and_bindings) = pattern_indices.first() {
74-
index_and_bindings.wrap_node(graph, current_node)
79+
cache.get_or_compute(
80+
&|graph, index_and_bindings: IndexAndBindings| {
81+
index_and_bindings.wrap_node(graph, current_node)
82+
},
83+
graph,
84+
index_and_bindings,
85+
)
7586
} else {
7687
false_branch
7788
}
@@ -119,6 +130,8 @@ pub fn create_graph_expr_match<'db>(
119130
})
120131
.collect();
121132

133+
let cache = Cache::default();
134+
122135
// TODO(lior): add diagnostics if there is an unreachable arm.
123136
let match_node_id = create_node_for_patterns(
124137
CreateNodeParams {
@@ -131,8 +144,14 @@ pub fn create_graph_expr_match<'db>(
131144
build_node_callback: &|graph, pattern_indices| {
132145
// TODO(lior): add diagnostics if pattern_indices is empty (instead of `unwrap`).
133146
let index_and_bindings = pattern_indices.first().unwrap();
134-
let index = index_and_bindings.index();
135-
index_and_bindings.wrap_node(graph, pattern_and_nodes[index].1)
147+
cache.get_or_compute(
148+
&|graph, index_and_bindings: IndexAndBindings| {
149+
let index = index_and_bindings.index();
150+
index_and_bindings.wrap_node(graph, pattern_and_nodes[index].1)
151+
},
152+
graph,
153+
index_and_bindings,
154+
)
136155
},
137156
location: matched_expr_location,
138157
},
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
use std::cell::RefCell;
2+
3+
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
4+
5+
use crate::lower::flow_control::graph::{FlowControlGraphBuilder, NodeId};
6+
7+
/// Implements a simple memoization mechanism to optimize the flow control graph that is created.
8+
///
9+
/// The cache is used before calling `BuildNodeCallback` to avoid creating two nodes that behave
10+
/// identically.
11+
pub struct Cache<Input> {
12+
// A map from input to the cached result.
13+
//
14+
// The cache is wrapped in a `RefCell` to allow modifying it without holding a `&mut` to it
15+
// (which would complicate its usage).
16+
cache: RefCell<UnorderedHashMap<Input, NodeId>>,
17+
}
18+
impl<Input: std::hash::Hash + Eq + Clone> Cache<Input> {
19+
/// Calls the callback if this is the first time the input is seen.
20+
/// Returns the previous result, otherwise.
21+
pub fn get_or_compute<'db>(
22+
&self,
23+
callback: &dyn Fn(&mut FlowControlGraphBuilder<'db>, Input) -> NodeId,
24+
graph: &mut FlowControlGraphBuilder<'db>,
25+
input: Input,
26+
) -> NodeId {
27+
if let Some(node_id) = self.cache.borrow().get(&input) {
28+
return *node_id;
29+
}
30+
31+
let node_id = callback(graph, input.clone());
32+
assert!(!self.cache.borrow().contains_key(&input));
33+
self.cache.borrow_mut().insert(input, node_id);
34+
node_id
35+
}
36+
}
37+
38+
impl<Input: std::hash::Hash + Eq + Clone> std::default::Default for Cache<Input> {
39+
fn default() -> Self {
40+
Self { cache: Default::default() }
41+
}
42+
}

crates/cairo-lang-lowering/src/lower/flow_control/create_graph/patterns.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use itertools::Itertools;
1010
use super::super::graph::{
1111
EnumMatch, FlowControlGraphBuilder, FlowControlNode, FlowControlVar, NodeId,
1212
};
13+
use super::cache::Cache;
1314
use super::filtered_patterns::{Bindings, FilteredPatterns};
1415
use crate::ids::LocationId;
1516
use crate::lower::context::LoweringContext;
@@ -82,10 +83,16 @@ pub fn create_node_for_patterns<'db>(
8283
})
8384
.collect_vec();
8485

85-
// Wrap `build_node_callback` to add the bindings to the patterns.
86+
let cache = Cache::default();
87+
88+
// Wrap `build_node_callback` to add the bindings to the patterns and cache the result.
8689
let build_node_callback = |graph: &mut FlowControlGraphBuilder<'db>,
8790
pattern_indices: FilteredPatterns| {
88-
build_node_callback(graph, pattern_indices.add_bindings(bindings.clone()))
91+
cache.get_or_compute(
92+
&build_node_callback,
93+
graph,
94+
pattern_indices.add_bindings(bindings.clone()),
95+
)
8996
};
9097

9198
// If all the patterns are catch-all, we do not need to look into `input_var`.

crates/cairo-lang-lowering/src/lower/flow_control/test_data/match

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -193,15 +193,14 @@ enum Color {
193193
}
194194

195195
//! > graph
196-
Root: 7
196+
Root: 6
197197
0 ArmExpr { expr: ExprId(1) }
198198
1 ArmExpr { expr: ExprId(2) }
199199
2 ArmExpr { expr: ExprId(3) }
200200
3 EnumMatch { matched_var: v2, variants: (NodeId(1), v8), (NodeId(2), v9), (NodeId(2), v10)}
201-
4 EnumMatch { matched_var: v2, variants: (NodeId(1), v12), (NodeId(2), v13), (NodeId(2), v14)}
202-
5 EnumMatch { matched_var: v1, variants: (NodeId(0), v3), (NodeId(3), v7), (NodeId(4), v11)}
203-
6 Deconstruct { input: v0, outputs: [v1, v2], next: NodeId(5) }
204-
7 EvaluateExpr { expr: ExprId(0), var_id: v0, next: NodeId(6) }
201+
4 EnumMatch { matched_var: v1, variants: (NodeId(0), v3), (NodeId(3), v7), (NodeId(3), v11)}
202+
5 Deconstruct { input: v0, outputs: [v1, v2], next: NodeId(4) }
203+
6 EvaluateExpr { expr: ExprId(0), var_id: v0, next: NodeId(5) }
205204

206205
//! > semantic_diagnostics
207206

0 commit comments

Comments
 (0)