Skip to content

Commit 2c50ca9

Browse files
committed
Bumping version to 0.2.0
1 parent a3f7b67 commit 2c50ca9

File tree

466 files changed

+11856
-17747
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

466 files changed

+11856
-17747
lines changed

.github/workflows/rust.yml

Lines changed: 0 additions & 24 deletions
This file was deleted.

.github/workflows/scallopy-torch.yml

Lines changed: 0 additions & 60 deletions
This file was deleted.

Cargo.toml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,28 @@
11
[workspace]
22
members = [
33
"core",
4+
5+
# Extra packages
46
"etc/codegen",
57
"etc/sclc",
68
"etc/scli",
79
"etc/sclrepl",
810
"etc/scallopy",
911
# "etc/scallop-node",
1012
"etc/scallop-wasm",
13+
14+
# Additional dependencies
15+
"lib/astnode-derive",
16+
"lib/parse_relative_duration",
1117
"lib/sdd",
12-
"lib/rsat",
13-
"lib/ram",
18+
19+
# Laboratory
20+
# "lab/rsat",
21+
# "lab/dyn-tensor-registry",
22+
# "lab/ram-egg",
23+
# "lab/ast-derive",
24+
# "lab/type-inference",
25+
# "lab/visitor",
1426
]
1527

1628
default-members = [

core/Cargo.toml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,12 @@ petgraph = "0.6"
2222
csv = "1.1"
2323
sprs = "0.11"
2424
chrono = { version = "0.4", features = ["serde"] }
25+
chronoutil = { git = "https://github.com/Liby99/chronoutil.git" }
2526
dateparser = "0.1.6"
26-
parse_duration = "2.1.1"
2727
dyn-clone = "1.0.10"
2828
lazy_static = "1.4"
2929
serde = { version = "1.0", features = ["derive"] }
30+
parse_relative_duration = { path = "../lib/parse_relative_duration" }
3031
rand = { version = "0.8", features = ["std_rng", "small_rng", "alloc"] }
32+
astnode-derive = { path = "../lib/astnode-derive" }
3133
sdd = { path = "../lib/sdd" }
32-
33-
# Optional ones
34-
tch = { version = "0.13.0", optional = true }
35-
36-
[features]
37-
torch-tensor = ["dep:tch"]
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
use std::collections::*;
2+
3+
use super::entity::*;
4+
use super::value::*;
5+
use super::value_type::*;
6+
7+
use crate::compiler::front::*;
8+
9+
#[derive(Debug, Clone)]
10+
pub struct ADTVariant {
11+
pub relation_name: String,
12+
pub arg_types: Vec<ValueType>,
13+
}
14+
15+
#[derive(Debug, Clone)]
16+
pub struct ADTVariantRegistry {
17+
registry: HashMap<String, ADTVariant>,
18+
}
19+
20+
impl ADTVariantRegistry {
21+
pub fn new() -> Self {
22+
Self {
23+
registry: HashMap::new(),
24+
}
25+
}
26+
27+
pub fn add(&mut self, variant_name: String, relation_name: String, arg_types: Vec<ValueType>) {
28+
let variant = ADTVariant {
29+
relation_name,
30+
arg_types,
31+
};
32+
self.registry.insert(variant_name, variant);
33+
}
34+
35+
pub fn iter(&self) -> std::collections::hash_map::Iter<String, ADTVariant> {
36+
self.registry.iter()
37+
}
38+
39+
pub fn parse(&self, s: &str) -> Result<ADTParseResult, ADTEntityError> {
40+
// First parse an entity from string
41+
let entity = parser::str_to_entity(s).map_err(ADTEntityError::Parsing)?;
42+
43+
// Completely parse the entity:
44+
// - check that there is no variable or expression involved
45+
// - check that all the mentioned ADT variants exist
46+
// - check all the constant type
47+
// - create intermediate ADTs
48+
// - generate a final ADT as the entity value
49+
let mut facts = Vec::new();
50+
let entity = self.parse_entity(&entity, &ValueType::Entity, &mut facts)?;
51+
52+
// Return the final result
53+
Ok(ADTParseResult { entity, facts })
54+
}
55+
56+
fn parse_entity(
57+
&self,
58+
entity: &ast::Entity,
59+
ty: &ValueType,
60+
facts: &mut Vec<(String, Value, Vec<Value>)>,
61+
) -> Result<Value, ADTEntityError> {
62+
match &entity {
63+
ast::Entity::Expr(e) => match e {
64+
ast::Expr::Constant(c) => match (&c, ty) {
65+
(ast::Constant::Integer(i), ValueType::I8) => Ok(Value::I8(i.int().clone() as i8)),
66+
(ast::Constant::Integer(i), ValueType::I16) => Ok(Value::I16(i.int().clone() as i16)),
67+
(ast::Constant::Integer(i), ValueType::I32) => Ok(Value::I32(i.int().clone() as i32)),
68+
(ast::Constant::Integer(i), ValueType::I64) => Ok(Value::I64(i.int().clone() as i64)),
69+
(ast::Constant::Integer(i), ValueType::I128) => Ok(Value::I128(i.int().clone() as i128)),
70+
(ast::Constant::Integer(i), ValueType::ISize) => Ok(Value::ISize(i.int().clone() as isize)),
71+
(ast::Constant::Integer(i), ValueType::U8) => Ok(Value::U8(i.int().clone() as u8)),
72+
(ast::Constant::Integer(i), ValueType::U16) => Ok(Value::U16(i.int().clone() as u16)),
73+
(ast::Constant::Integer(i), ValueType::U32) => Ok(Value::U32(i.int().clone() as u32)),
74+
(ast::Constant::Integer(i), ValueType::U64) => Ok(Value::U64(i.int().clone() as u64)),
75+
(ast::Constant::Integer(i), ValueType::U128) => Ok(Value::U128(i.int().clone() as u128)),
76+
(ast::Constant::Integer(i), ValueType::USize) => Ok(Value::USize(i.int().clone() as usize)),
77+
(ast::Constant::Integer(i), ValueType::F32) => Ok(Value::F32(i.int().clone() as f32)),
78+
(ast::Constant::Integer(i), ValueType::F64) => Ok(Value::F64(i.int().clone() as f64)),
79+
(ast::Constant::Float(f), ValueType::F32) => Ok(Value::F32(f.float().clone() as f32)),
80+
(ast::Constant::Float(f), ValueType::F64) => Ok(Value::F64(f.float().clone() as f64)),
81+
(ast::Constant::Boolean(b), ValueType::Bool) => Ok(Value::Bool(b.value().clone())),
82+
(ast::Constant::Char(c), ValueType::Char) => Ok(Value::Char(c.parse_char())),
83+
(ast::Constant::String(s), ValueType::String) => Ok(Value::String(s.string().clone())),
84+
_ => Err(ADTEntityError::CannotUnifyType),
85+
},
86+
_ => Err(ADTEntityError::InvalidExpr),
87+
},
88+
ast::Entity::Object(o) => {
89+
let variant_name = o.functor().name();
90+
if let Some(variant) = self.registry.get(variant_name) {
91+
let expected_arity = variant.arg_types.len() - 1;
92+
let actual_arity = o.args().len();
93+
if expected_arity == actual_arity {
94+
// Compute the arguments
95+
let parsed_args = o
96+
.args()
97+
.iter()
98+
.zip(variant.arg_types.iter().skip(1))
99+
.map(|(arg, arg_ty)| self.parse_entity(arg, arg_ty, facts))
100+
.collect::<Result<Vec<_>, _>>()?;
101+
102+
// Aggregate them into a hash
103+
let entity = Value::Entity(encode_entity(variant_name, parsed_args.iter()));
104+
105+
// Create a new fact to insert
106+
let fact = (variant.relation_name.clone(), entity.clone(), parsed_args);
107+
facts.push(fact);
108+
109+
// Return the entity as result
110+
Ok(entity)
111+
} else {
112+
Err(ADTEntityError::ArityMismatch {
113+
variant: variant_name.to_string(),
114+
expected: expected_arity,
115+
actual: actual_arity,
116+
})
117+
}
118+
} else {
119+
Err(ADTEntityError::UnknownVariant(variant_name.to_string()))
120+
}
121+
}
122+
}
123+
}
124+
}
125+
126+
#[derive(Clone, Debug)]
127+
pub struct ADTParseResult {
128+
pub entity: Value,
129+
pub facts: Vec<(String, Value, Vec<Value>)>,
130+
}
131+
132+
#[derive(Clone, Debug)]
133+
pub enum ADTEntityError {
134+
Parsing(parser::ParserError),
135+
InvalidExpr,
136+
CannotUnifyType,
137+
UnknownVariant(String),
138+
ArityMismatch {
139+
variant: String,
140+
expected: usize,
141+
actual: usize,
142+
},
143+
}
144+
145+
impl std::fmt::Display for ADTEntityError {
146+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147+
match self {
148+
Self::Parsing(p) => p.fmt(f),
149+
Self::InvalidExpr => f.write_str("Invalid Expression"),
150+
Self::CannotUnifyType => f.write_str("Cannot unify type"),
151+
Self::UnknownVariant(v) => f.write_fmt(format_args!("Unknown variant `{}`", v)),
152+
Self::ArityMismatch {
153+
variant,
154+
expected,
155+
actual,
156+
} => f.write_fmt(format_args!(
157+
"Arity mismatch for variant `{}`, expected {}, found {}",
158+
variant, expected, actual
159+
)),
160+
}
161+
}
162+
}

core/src/common/aggregate_op.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use super::value_type::*;
55
/// The aggregate operators for low level representation
66
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
77
pub enum AggregateOp {
8-
Count,
8+
Count { discrete: bool },
99
Sum(ValueType),
1010
Prod(ValueType),
1111
Min,
@@ -20,7 +20,7 @@ pub enum AggregateOp {
2020
impl std::fmt::Display for AggregateOp {
2121
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2222
match self {
23-
Self::Count => f.write_str("count"),
23+
Self::Count { discrete } => if *discrete { f.write_str("discrete_count") } else { f.write_str("count") },
2424
Self::Sum(t) => f.write_fmt(format_args!("sum<{}>", t)),
2525
Self::Prod(t) => f.write_fmt(format_args!("prod<{}>", t)),
2626
Self::Min => f.write_str("min"),
@@ -35,6 +35,14 @@ impl std::fmt::Display for AggregateOp {
3535
}
3636

3737
impl AggregateOp {
38+
pub fn count() -> Self {
39+
Self::Count { discrete: false }
40+
}
41+
42+
pub fn discrete_count() -> Self {
43+
Self::Count { discrete: true }
44+
}
45+
3846
pub fn min(has_arg: bool) -> Self {
3947
if has_arg {
4048
Self::Argmin

0 commit comments

Comments
 (0)