Skip to content

Commit 14be23c

Browse files
jrodal98facebook-github-bot
authored andcommitted
dag added (#677)
Summary: Pull Request resolved: #677 ## What * Create the dag struct with `next_column` implemented * add Clone + Copy to ColumnMetadata ## Why * So that we can iterate through the columns in order ## Future improvements * Implement Iter trait so that we can iterate on the dag * Support iterating / getting the next column level so that you can parallelize computation of some columns Reviewed By: gorel Differential Revision: D34772282 fbshipit-source-id: d50a05bd21c57288bdddec218bad4ffa625a50c1
1 parent 32482d1 commit 14be23c

File tree

5 files changed

+152
-5
lines changed

5 files changed

+152
-5
lines changed

fbpcs/kodiak/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ edition = "2021"
77
[dependencies]
88
derive_more = "0.99.3"
99
log = { version = "0.4.14", features = ["kv_unstable", "kv_unstable_std"] }
10+
petgraph = "0.6"

fbpcs/kodiak/src/column_metadata.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
use crate::mpc_metric_dtype::MPCMetricDType;
99
use crate::row::Row;
1010

11-
pub trait ColumnMetadata: std::cmp::Eq + std::hash::Hash + Sized {
11+
pub trait ColumnMetadata: std::cmp::Eq + std::hash::Hash + Sized + Clone + Copy {
1212
/// Used to look up a human-readable name for this metric.
1313
/// Should be known at compile time, so &'static is fine.
1414
fn name(&self) -> &'static str;
@@ -34,7 +34,7 @@ macro_rules! column_metadata {
3434
$($variant:ident -> [$($deps:ident),*]),*,
3535
}) => {
3636

37-
#[derive(Debug, PartialEq, Eq, std::hash::Hash)]
37+
#[derive(Copy, Clone, Debug, PartialEq, Eq, std::hash::Hash)]
3838
pub enum $name {
3939
$($variant),*
4040
}

fbpcs/kodiak/src/dag.rs

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
use crate::column_metadata::ColumnMetadata;
9+
use crate::mpc_view::MPCView;
10+
use petgraph::graph::DiGraph;
11+
12+
pub struct Dag<T: ColumnMetadata> {
13+
graph: DiGraph<T, usize>,
14+
sorted_columns: Vec<T>,
15+
current_index: usize,
16+
}
17+
18+
impl<T: ColumnMetadata> Dag<T> {
19+
pub fn next_column(&mut self) -> Option<&T> {
20+
let node = self.sorted_columns.get(self.current_index);
21+
self.current_index += 1;
22+
node
23+
}
24+
25+
pub fn next_columns(&mut self) -> Option<Vec<&T>> {
26+
unimplemented!("Support for per-level operation not implemented")
27+
}
28+
29+
pub fn reset(&mut self) {
30+
self.current_index = 0;
31+
}
32+
33+
pub fn from_mpc_view(mpc_view: MPCView<T>) -> Self {
34+
let graph = Self::build_graph(
35+
&mpc_view.input_columns,
36+
&mpc_view.helper_columns,
37+
&mpc_view.metrics,
38+
);
39+
40+
let sorted_nodes = Self::toposort_nodes(&graph);
41+
42+
Self {
43+
graph,
44+
sorted_columns: sorted_nodes,
45+
current_index: 0,
46+
}
47+
}
48+
49+
fn build_graph(input_columns: &[T], helper_columns: &[T], metrics: &[T]) -> DiGraph<T, usize> {
50+
let mut graph = DiGraph::new();
51+
52+
// 1. add the metric column data to the graph
53+
// 2. insert a mapping from node -> graph index
54+
// 3. repeat for every column
55+
let node_to_index = input_columns
56+
.iter()
57+
.chain(helper_columns.iter().chain(metrics.iter()))
58+
.fold(std::collections::HashMap::new(), |mut acc, &node| {
59+
let i = graph.add_node(node.clone());
60+
acc.insert(node, i);
61+
acc
62+
});
63+
64+
for to_node in input_columns
65+
.iter()
66+
.chain(helper_columns.iter().chain(metrics.iter()))
67+
{
68+
let to_index = node_to_index
69+
.get(to_node)
70+
.unwrap_or_else(|| panic!("Column {} was not found in the graph", to_node.name()));
71+
for from_node in to_node.dependencies().iter() {
72+
let from_index = node_to_index.get(from_node).unwrap_or_else(|| {
73+
panic!("Column {} was not found in the graph", from_node.name())
74+
});
75+
// 0 is the weight
76+
graph.add_edge(*from_index, *to_index, 0);
77+
}
78+
}
79+
80+
graph
81+
}
82+
83+
fn toposort_nodes(graph: &DiGraph<T, usize>) -> Vec<T> {
84+
match petgraph::algo::toposort(&graph, None) {
85+
Ok(order) => order
86+
.into_iter()
87+
.map(|node_index| *graph.node_weight(node_index).unwrap())
88+
.collect(),
89+
Err(_e) => panic!("Cycle detected in graph"),
90+
}
91+
}
92+
}
93+
94+
#[cfg(test)]
95+
mod tests {
96+
use crate::column_metadata;
97+
use crate::column_metadata::ColumnMetadata;
98+
use crate::dag::Dag;
99+
use crate::mpc_metric_dtype::MPCMetricDType;
100+
use crate::mpc_view::MPCView;
101+
use crate::row::Row;
102+
103+
column_metadata! {
104+
TestEnum {
105+
Variant1 -> [],
106+
Variant2 -> [Variant1],
107+
Variant3 -> [Variant1],
108+
Variant4 -> [Variant2, Variant3],
109+
}
110+
}
111+
112+
impl TestEnum {
113+
fn from_row(&self, _r: &Row<Self>) -> MPCMetricDType {
114+
panic!("Undefined for test");
115+
}
116+
fn aggregate<I: Iterator<Item = Row<Self>>>(&self, _rows: I) -> MPCMetricDType {
117+
panic!("Undefined for test");
118+
}
119+
}
120+
121+
fn get_mpc_view() -> MPCView<TestEnum> {
122+
MPCView::new(
123+
vec![TestEnum::Variant1],
124+
vec![TestEnum::Variant2, TestEnum::Variant3],
125+
vec![TestEnum::Variant4],
126+
vec![],
127+
)
128+
}
129+
130+
#[test]
131+
fn dag_next_node() {
132+
let mpc_view = get_mpc_view();
133+
let mut dag = Dag::from_mpc_view(mpc_view);
134+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant1));
135+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant3));
136+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant2));
137+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant4));
138+
assert_eq!(dag.next_column(), None);
139+
dag.reset();
140+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant1));
141+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant3));
142+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant2));
143+
assert_eq!(dag.next_column(), Some(&TestEnum::Variant4));
144+
}
145+
}

fbpcs/kodiak/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
*/
77

88
pub mod column_metadata;
9+
pub mod dag;
910
pub mod execution_config;
1011
pub mod input_reader;
1112
pub mod metric_config;

fbpcs/kodiak/src/mpc_view.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
use crate::column_metadata::ColumnMetadata;
99

1010
pub struct MPCView<T: ColumnMetadata> {
11-
input_columns: Vec<T>,
12-
helper_columns: Vec<T>,
13-
metrics: Vec<T>,
11+
pub input_columns: Vec<T>,
12+
pub helper_columns: Vec<T>,
13+
pub metrics: Vec<T>,
1414
grouping_sets: Vec<Vec<T>>,
1515
}
1616

0 commit comments

Comments
 (0)