Skip to content

Commit da21ec8

Browse files
committed
Bumping version to 0.2.4
1 parent 847d68f commit da21ec8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

78 files changed

+2078
-304
lines changed

.github/workflows/scallopy.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@ jobs:
2323
max-parallel: 5
2424
matrix:
2525
python-version:
26-
- "3.8"
27-
- "3.9"
2826
- "3.10"
2927

3028
steps:

changelog.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,14 @@
1+
# v0.2.4, Aug 30, 2024
2+
3+
- Rule tags can now be expressions with potential reference to local variables: `rel 1/n::head() = body(n)`
4+
- Allowing for sparse gradient computation inside Scallopy to minimize memory footprint
5+
- Allowing users to specify per-datapoint output mapping inside Scallopy
6+
- Adding destructor syntax so that ADTs can be used in a more idiomatic way
7+
- Unifying the behavior of integer overflow inside Scallop
8+
- Multiple bugs fixed
9+
10+
# v0.2.3, Jun 23, 2024
11+
112
# v0.2.2, Oct 25, 2023
213

314
- Adding `wmc_with_disjunctions` option for provenances that deal with boolean formulas for more accurate probability estimation

core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "scallop-core"
3-
version = "0.2.2"
3+
version = "0.2.4"
44
authors = ["Ziyang Li <[email protected]>"]
55
edition = "2018"
66

core/src/common/foreign_predicate.rs

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,29 +40,29 @@ impl Binding {
4040
}
4141
}
4242

43-
/// The identifier of a foreign predicate in a registry
44-
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
45-
pub struct ForeignPredicateIdentifier {
46-
identifier: String,
47-
types: Box<[ValueType]>,
48-
binding_pattern: BindingPattern,
49-
}
50-
51-
impl std::fmt::Display for ForeignPredicateIdentifier {
52-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53-
f.write_fmt(format_args!(
54-
"pred {}[{}]({})",
55-
self.identifier,
56-
self.binding_pattern,
57-
self
58-
.types
59-
.iter()
60-
.map(|t| format!("{}", t))
61-
.collect::<Vec<_>>()
62-
.join(", ")
63-
))
64-
}
65-
}
43+
// /// The identifier of a foreign predicate in a registry
44+
// #[derive(Clone, Debug, Hash, PartialEq, Eq)]
45+
// pub struct ForeignPredicateIdentifier {
46+
// identifier: String,
47+
// types: Box<[ValueType]>,
48+
// binding_pattern: BindingPattern,
49+
// }
50+
51+
// impl std::fmt::Display for ForeignPredicateIdentifier {
52+
// fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53+
// f.write_fmt(format_args!(
54+
// "pred {}[{}]({})",
55+
// self.identifier,
56+
// self.binding_pattern,
57+
// self
58+
// .types
59+
// .iter()
60+
// .map(|t| format!("{}", t))
61+
// .collect::<Vec<_>>()
62+
// .join(", ")
63+
// ))
64+
// }
65+
// }
6666

6767
/// A binding pattern for a predicate, e.g. bbf
6868
#[derive(Clone, Debug, Hash, PartialEq, Eq)]

core/src/compiler/back/compile.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ impl Program {
1717
// Perform rule level optimizations
1818
for rule in &mut self.rules {
1919
// First propagate equality
20-
optimizations::propagate_equality(rule);
20+
optimizations::propagate_equality(rule, &self.predicate_registry);
2121

2222
// Enter the loop of constant folding/propagation
2323
loop {

core/src/compiler/back/optimizations/equality_propagation.rs

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
use std::collections::*;
22

3+
use crate::common::foreign_predicate::*;
4+
35
use super::super::*;
46

5-
pub fn propagate_equality(rule: &mut Rule) {
7+
pub fn propagate_equality(rule: &mut Rule, foreign_predicate_registry: &ForeignPredicateRegistry) {
68
let mut substitutions = HashMap::<_, Variable>::new();
79
let mut ignore_literals = HashSet::new();
810
let mut cannot_substitute = HashSet::<Variable>::new();
@@ -18,7 +20,7 @@ pub fn propagate_equality(rule: &mut Rule) {
1820
}
1921

2022
// Find all the bounded variables by atom and assign
21-
let bounded = bounded_by_atom_and_assign(rule);
23+
let bounded = bounded_by_atom_and_assign(rule, foreign_predicate_registry);
2224

2325
// Collect all substitutions
2426
for (i, literal) in rule.body_literals().enumerate() {
@@ -136,14 +138,26 @@ pub fn propagate_equality(rule: &mut Rule) {
136138
attributes: rule.attributes.clone(),
137139
head: new_head,
138140
body: Conjunction { args: new_literals },
139-
}
141+
};
140142
}
141143

142-
fn bounded_by_atom_and_assign(rule: &Rule) -> HashSet<Variable> {
144+
fn bounded_by_atom_and_assign(rule: &Rule, foreign_predicate_registry: &ForeignPredicateRegistry) -> HashSet<Variable> {
143145
let mut bounded = rule
144146
.body_literals()
145147
.flat_map(|l| match l {
146-
Literal::Atom(a) => a.variable_args().cloned().collect::<Vec<_>>(),
148+
Literal::Atom(atom) => {
149+
if let Some(fp) = foreign_predicate_registry.get(&atom.predicate) {
150+
// If atom is on foreign predicate, only the variables that are free will be bounded
151+
atom.args[fp.num_bounded()..fp.arity()]
152+
.iter()
153+
.filter_map(|term| term.as_variable())
154+
.cloned()
155+
.collect::<Vec<_>>()
156+
} else {
157+
// If atom is on a normal relation, all the variables will be bounded
158+
atom.variable_args().cloned().collect::<Vec<_>>()
159+
}
160+
}
147161
_ => vec![],
148162
})
149163
.collect::<HashSet<_>>();

core/src/compiler/front/analysis.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ pub struct Analysis {
1818
pub constant_decl_analysis: ConstantDeclAnalysis,
1919
pub adt_analysis: AlgebraicDataTypeAnalysis,
2020
pub head_relation_analysis: HeadRelationAnalysis,
21+
pub tagged_rule_analysis: TaggedRuleAnalysis,
2122
pub type_inference: TypeInference,
2223
pub boundness_analysis: BoundnessAnalysis,
2324
pub demand_attr_analysis: DemandAttributeAnalysis,
@@ -41,6 +42,7 @@ impl Analysis {
4142
constant_decl_analysis: ConstantDeclAnalysis::new(),
4243
adt_analysis: AlgebraicDataTypeAnalysis::new(),
4344
head_relation_analysis: HeadRelationAnalysis::new(predicate_registry),
45+
tagged_rule_analysis: TaggedRuleAnalysis::new(),
4446
type_inference: TypeInference::new(function_registry, predicate_registry, aggregate_registry),
4547
boundness_analysis: BoundnessAnalysis::new(predicate_registry),
4648
demand_attr_analysis: DemandAttributeAnalysis::new(),
@@ -78,12 +80,15 @@ impl Analysis {
7880
items.walk(&mut analyzers);
7981
}
8082

81-
pub fn post_analysis(&mut self) {
83+
pub fn post_analysis(&mut self, foreign_predicate_registry: &mut ForeignPredicateRegistry) {
8284
self.head_relation_analysis.compute_errors();
8385
self.type_inference.check_query_predicates();
8486
self.type_inference.infer_types();
8587
self.demand_attr_analysis.check_arity(&self.type_inference);
8688
self.boundness_analysis.check_boundness(&self.demand_attr_analysis);
89+
self
90+
.tagged_rule_analysis
91+
.register_predicates(&self.type_inference, foreign_predicate_registry);
8792
}
8893

8994
pub fn dump_errors(&mut self, error_ctx: &mut FrontCompileError) {
@@ -98,5 +103,6 @@ impl Analysis {
98103
error_ctx.extend(&mut self.type_inference.errors);
99104
error_ctx.extend(&mut self.boundness_analysis.errors);
100105
error_ctx.extend(&mut self.demand_attr_analysis.errors);
106+
error_ctx.extend(&mut self.tagged_rule_analysis.errors);
101107
}
102108
}

core/src/compiler/front/analyzers/constant_decl.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ impl NodeVisitor<FactDecl> for ConstantDeclAnalysis {
216216
for v in vars {
217217
if self.variables.contains_key(v.variable_name()) {
218218
self.variable_use.insert(v.location().clone(), v.name().to_string());
219-
} else {
219+
} else if !fact_decl.atom().iter_args().any(|arg| arg.is_destruct()) {
220220
self.errors.push(ConstantDeclError::UnknownConstantVariable {
221221
name: v.name().to_string(),
222222
loc: v.location().clone(),

core/src/compiler/front/analyzers/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub mod input_files;
1010
pub mod invalid_constant;
1111
pub mod invalid_wildcard;
1212
pub mod output_files;
13+
pub mod tagged_rule;
1314
pub mod type_inference;
1415

1516
pub use aggregation::AggregationAnalysis;
@@ -24,6 +25,7 @@ pub use input_files::InputFilesAnalysis;
2425
pub use invalid_constant::InvalidConstantAnalyzer;
2526
pub use invalid_wildcard::InvalidWildcardAnalyzer;
2627
pub use output_files::OutputFilesAnalysis;
28+
pub use tagged_rule::TaggedRuleAnalysis;
2729
pub use type_inference::TypeInference;
2830

2931
pub mod errors {
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
use lazy_static::lazy_static;
2+
use std::collections::*;
3+
4+
use crate::common::expr;
5+
use crate::common::foreign_predicate::*;
6+
use crate::common::input_tag::*;
7+
use crate::common::tuple::*;
8+
use crate::common::unary_op;
9+
use crate::common::value::*;
10+
use crate::common::value_type::*;
11+
12+
use crate::compiler::front::*;
13+
use crate::runtime::env::*;
14+
15+
lazy_static! {
16+
pub static ref TAG_TYPE: Vec<ValueType> = {
17+
use ValueType::*;
18+
vec![F64, F32, Bool]
19+
};
20+
}
21+
22+
#[derive(Clone, Debug)]
23+
pub struct TaggedRuleAnalysis {
24+
pub to_add_tag_predicates: HashMap<ast::NodeLocation, ToAddTagPredicate>,
25+
pub errors: Vec<FrontCompileErrorMessage>,
26+
}
27+
28+
impl TaggedRuleAnalysis {
29+
pub fn new() -> Self {
30+
Self {
31+
to_add_tag_predicates: HashMap::new(),
32+
errors: Vec::new(),
33+
}
34+
}
35+
36+
pub fn add_tag_predicate(
37+
&mut self,
38+
rule_id: ast::NodeLocation,
39+
name: String,
40+
arg_name: String,
41+
tag_loc: ast::NodeLocation,
42+
) {
43+
let pred = ToAddTagPredicate::new(name, arg_name, tag_loc);
44+
self.to_add_tag_predicates.insert(rule_id, pred);
45+
}
46+
47+
pub fn register_predicates(
48+
&mut self,
49+
type_inference: &super::TypeInference,
50+
foreign_predicate_registry: &mut ForeignPredicateRegistry,
51+
) {
52+
for (rule_id, tag_predicate) in self.to_add_tag_predicates.drain() {
53+
if let Some(rule_variable_type) = type_inference.rule_variable_type.get(&rule_id) {
54+
if let Some(var_ty) = rule_variable_type.get(&tag_predicate.arg_name) {
55+
match get_target_tag_type(var_ty, &tag_predicate.tag_loc) {
56+
Ok(target_tag_ty) => {
57+
// This means that we have an okay tag that is type checked
58+
// Create a foreign predicate and register it
59+
let fp = TagPredicate::new(tag_predicate.name.clone(), target_tag_ty);
60+
if let Err(err) = foreign_predicate_registry.register(fp) {
61+
self.errors.push(FrontCompileErrorMessage::error().msg(err.to_string()));
62+
}
63+
}
64+
Err(err) => {
65+
self.errors.push(err);
66+
}
67+
}
68+
}
69+
}
70+
}
71+
}
72+
}
73+
74+
fn get_target_tag_type(
75+
var_ty: &analyzers::type_inference::TypeSet,
76+
loc: &ast::NodeLocation,
77+
) -> Result<ValueType, FrontCompileErrorMessage> {
78+
// Top priority: if var_ty is a base type, directly check if it is among some expected type
79+
if let Some(base_ty) = var_ty.get_base_type() {
80+
if TAG_TYPE.contains(&base_ty) {
81+
return Ok(base_ty);
82+
}
83+
}
84+
85+
// Then we check if the value can be casted into certain types
86+
for tag_ty in TAG_TYPE.iter() {
87+
if var_ty.can_type_cast(tag_ty) {
88+
return Ok(var_ty.to_default_value_type());
89+
}
90+
}
91+
92+
// If not, then
93+
return Err(
94+
FrontCompileErrorMessage::error()
95+
.msg(format!(
96+
"A value of type `{var_ty}` cannot be casted into a dynamic tag"
97+
))
98+
.src(loc.clone()),
99+
);
100+
}
101+
102+
/// The information of a helper tag predicate
103+
///
104+
/// Suppose we have a rule
105+
/// ``` ignore
106+
/// rel 1/p :: head() = body(p)
107+
/// ```
108+
///
109+
/// This rule will be transformed into
110+
/// ``` ignore
111+
/// rel head() = body(p) and tag#head#1#var == 1 / p and tag#head#1(tag#head#1#var)
112+
/// ```
113+
#[derive(Clone, Debug)]
114+
pub struct ToAddTagPredicate {
115+
/// The name of the predicate
116+
pub name: String,
117+
118+
/// The main tag expression
119+
pub arg_name: String,
120+
121+
/// Tag location
122+
pub tag_loc: ast::NodeLocation,
123+
}
124+
125+
impl ToAddTagPredicate {
126+
pub fn new(name: String, arg_name: String, tag_loc: ast::NodeLocation) -> Self {
127+
Self {
128+
name,
129+
arg_name,
130+
tag_loc,
131+
}
132+
}
133+
}
134+
135+
/// An actual predicate
136+
#[derive(Clone, Debug)]
137+
pub struct TagPredicate {
138+
/// The name of he predicate
139+
pub name: String,
140+
141+
/// args
142+
pub arg_ty: ValueType,
143+
}
144+
145+
impl TagPredicate {
146+
pub fn new(name: String, arg_ty: ValueType) -> Self {
147+
Self { name, arg_ty }
148+
}
149+
}
150+
151+
impl ForeignPredicate for TagPredicate {
152+
fn name(&self) -> String {
153+
self.name.clone()
154+
}
155+
156+
fn arity(&self) -> usize {
157+
1
158+
}
159+
160+
fn argument_type(&self, i: usize) -> ValueType {
161+
assert_eq!(i, 0);
162+
self.arg_ty.clone()
163+
}
164+
165+
fn num_bounded(&self) -> usize {
166+
1
167+
}
168+
169+
fn evaluate_with_env(&self, env: &RuntimeEnvironment, bounded: &[Value]) -> Vec<(DynamicInputTag, Vec<Value>)> {
170+
// Result tuple
171+
let tup = vec![];
172+
173+
// Create a type cast expression and evaluate it on the given values
174+
let tuple = Tuple::from_values(bounded.iter().cloned());
175+
let cast_expr = expr::Expr::unary(unary_op::UnaryOp::TypeCast(ValueType::F64), expr::Expr::access(0));
176+
let maybe_computed_tag = env.eval(&cast_expr, &tuple);
177+
178+
// Return the value
179+
if let Some(Tuple::Value(Value::F64(f))) = maybe_computed_tag {
180+
vec![(DynamicInputTag::Float(f), tup)]
181+
} else {
182+
vec![]
183+
}
184+
}
185+
}

0 commit comments

Comments
 (0)