diff --git a/aw-query/src/functions.rs b/aw-query/src/functions.rs index 31245560..79bc6b47 100644 --- a/aw-query/src/functions.rs +++ b/aw-query/src/functions.rs @@ -7,97 +7,97 @@ pub type QueryFn = fn(args: Vec, env: &VarEnv, ds: &Datastore) -> Result; pub fn fill_env(env: &mut VarEnv) { - env.insert( + env.declare_static( "print".to_string(), DataType::Function("print".to_string(), qfunctions::print), ); - env.insert( + env.declare_static( "query_bucket".to_string(), DataType::Function("query_bucket".to_string(), qfunctions::query_bucket), ); - env.insert( + env.declare_static( "query_bucket_names".to_string(), DataType::Function( "query_bucket_names".to_string(), qfunctions::query_bucket_names, ), ); - env.insert( + env.declare_static( "sort_by_duration".to_string(), DataType::Function("sort_by_duration".to_string(), qfunctions::sort_by_duration), ); - env.insert( + env.declare_static( "sort_by_timestamp".to_string(), DataType::Function( "sort_by_timestamp".to_string(), qfunctions::sort_by_timestamp, ), ); - env.insert( + env.declare_static( "sum_durations".to_string(), DataType::Function("sum_durations".to_string(), qfunctions::sum_durations), ); - env.insert( + env.declare_static( "limit_events".to_string(), DataType::Function("limit_events".to_string(), qfunctions::limit_events), ); - env.insert( + env.declare_static( "contains".to_string(), DataType::Function("contains".to_string(), qfunctions::contains), ); - env.insert( + env.declare_static( "flood".to_string(), DataType::Function("flood".to_string(), qfunctions::flood), ); - env.insert( + env.declare_static( "find_bucket".to_string(), DataType::Function("find_bucket".to_string(), qfunctions::find_bucket), ); - env.insert( + env.declare_static( "merge_events_by_keys".to_string(), DataType::Function( "merge_events_by_keys".to_string(), qfunctions::merge_events_by_keys, ), ); - env.insert( + env.declare_static( "chunk_events_by_key".to_string(), DataType::Function( "chunk_events_by_key".to_string(), qfunctions::chunk_events_by_key, ), ); - env.insert( + env.declare_static( "filter_keyvals".to_string(), DataType::Function("filter_keyvals".to_string(), qfunctions::filter_keyvals), ); - env.insert( + env.declare_static( "filter_keyvals_regex".to_string(), DataType::Function( "filter_keyvals_regex".to_string(), qfunctions::filter_keyvals_regex, ), ); - env.insert( + env.declare_static( "filter_period_intersect".to_string(), DataType::Function( "filter_period_intersect".to_string(), qfunctions::filter_period_intersect, ), ); - env.insert( + env.declare_static( "split_url_events".to_string(), DataType::Function("split_url_events".to_string(), qfunctions::split_url_events), ); - env.insert( + env.declare_static( "concat".to_string(), DataType::Function("concat".to_string(), qfunctions::concat), ); - env.insert( + env.declare_static( "categorize".to_string(), DataType::Function("categorize".into(), qfunctions::categorize), ); - env.insert( + env.declare_static( "tag".to_string(), DataType::Function("tag".into(), qfunctions::tag), ); @@ -521,7 +521,7 @@ mod validate { } pub fn get_timeinterval(env: &VarEnv) -> Result { - let interval_str = match env.get("TIMEINTERVAL") { + let interval_str = match env.deprecated_get("TIMEINTERVAL") { Some(data_ti) => match data_ti { DataType::String(ti_str) => ti_str, _ => { diff --git a/aw-query/src/interpret.rs b/aw-query/src/interpret.rs index 80f4aa97..8e0f8bb5 100644 --- a/aw-query/src/interpret.rs +++ b/aw-query/src/interpret.rs @@ -1,43 +1,27 @@ use std::collections::HashMap; -use crate::functions; - use aw_datastore::Datastore; -use aw_models::TimeInterval; use crate::ast::*; use crate::DataType; use crate::QueryError; - -pub type VarEnv = HashMap; - -fn init_env(ti: &TimeInterval) -> VarEnv { - let mut env = HashMap::new(); - env.insert("TIMEINTERVAL".to_string(), DataType::String(ti.to_string())); - functions::fill_env(&mut env); - env -} +use crate::VarEnv; pub fn interpret_prog( p: Program, - ti: &TimeInterval, + env: &mut VarEnv, ds: &Datastore, ) -> Result { - let mut env = init_env(ti); for expr in p.stmts { - interpret_expr(&mut env, ds, expr)?; + interpret_expr(env, ds, expr)?; } - match env.remove("RETURN") { + match env.take("RETURN") { Some(ret) => Ok(ret), None => Err(QueryError::EmptyQuery()), } } -fn interpret_expr( - env: &mut HashMap, - ds: &Datastore, - expr: Expr, -) -> Result { +fn interpret_expr(env: &mut VarEnv, ds: &Datastore, expr: Expr) -> Result { use crate::ast::Expr_::*; match expr.node { Add(a, b) => { @@ -184,9 +168,8 @@ fn interpret_expr( env.insert(var, val); Ok(DataType::None()) } - // FIXME: avoid clone, it's slow - Var(var) => match env.get(&var) { - Some(v) => Ok(v.clone()), + Var(var) => match env.take(&var) { + Some(v) => Ok(v), None => Err(QueryError::VariableNotDefined(var.to_string())), }, Bool(lit) => Ok(DataType::Bool(lit)), @@ -195,6 +178,7 @@ fn interpret_expr( Return(e) => { let val = interpret_expr(env, ds, *e)?; // TODO: Once RETURN is deprecated we can fix this + env.declare("RETURN".to_string()); env.insert("RETURN".to_string(), val); Ok(DataType::None()) } @@ -215,7 +199,7 @@ fn interpret_expr( DataType::List(l) => l, _ => unreachable!(), }; - let var = match env.get(&fname[..]) { + let var = match env.take(&fname[..]) { Some(v) => v, None => return Err(QueryError::VariableNotDefined(fname.clone())), }; diff --git a/aw-query/src/lib.rs b/aw-query/src/lib.rs index c9956815..65f63bac 100644 --- a/aw-query/src/lib.rs +++ b/aw-query/src/lib.rs @@ -21,9 +21,11 @@ mod lexer; unused_braces )] mod parser; +mod preprocess; +mod varenv; pub use crate::datatype::DataType; -pub use crate::interpret::VarEnv; +pub use crate::varenv::VarEnv; // TODO: add line numbers to errors // (works during lexing, but not during parsing I believe) @@ -50,6 +52,13 @@ impl fmt::Display for QueryError { } } +fn init_env(ti: &TimeInterval) -> VarEnv { + let mut env = VarEnv::new(); + env.declare_static("TIMEINTERVAL".to_string(), DataType::String(ti.to_string())); + functions::fill_env(&mut env); + env +} + pub fn query(code: &str, ti: &TimeInterval, ds: &Datastore) -> Result { let lexer = lexer::Lexer::new(code); let program = match parser::parse(lexer) { @@ -60,5 +69,7 @@ pub fn query(code: &str, ti: &TimeInterval, ds: &Datastore) -> Result Result<(), QueryError> { + for expr in &p.stmts { + preprocess_expr(env, ds, expr)?; + } + Ok(()) +} + +fn preprocess_expr(env: &mut VarEnv, ds: &Datastore, expr: &Expr) -> Result<(), QueryError> { + use crate::ast::Expr_::*; + match &expr.node { + Var(ref var) => env.add_ref(var)?, + Add(ref a, ref b) => { + preprocess_expr(env, ds, a)?; + preprocess_expr(env, ds, b)?; + } + Sub(ref a, ref b) => { + preprocess_expr(env, ds, a)?; + preprocess_expr(env, ds, b)?; + } + Mul(ref a, ref b) => { + preprocess_expr(env, ds, a)?; + preprocess_expr(env, ds, b)?; + } + Div(ref a, ref b) => { + preprocess_expr(env, ds, a)?; + preprocess_expr(env, ds, b)?; + } + Mod(ref a, ref b) => { + preprocess_expr(env, ds, a)?; + preprocess_expr(env, ds, b)?; + } + Equal(ref a, ref b) => { + preprocess_expr(env, ds, a)?; + preprocess_expr(env, ds, b)?; + } + Assign(ref var, ref b) => { + preprocess_expr(env, ds, b)?; + env.declare(var.to_string()); + } + Function(ref fname, ref args) => { + env.add_ref(fname)?; + preprocess_expr(env, ds, args)?; + } + If(ref ifs) => { + for (cond, block) in ifs { + // TODO: could be optimized? + preprocess_expr(env, ds, cond)?; + for expr in block { + preprocess_expr(env, ds, expr)?; + } + } + } + List(list) => { + for entry in list { + preprocess_expr(env, ds, entry)?; + } + } + Dict(d) => { + for (key, val_uninterpreted) in d { + preprocess_expr(env, ds, val_uninterpreted)?; + } + } + Return(e) => { + preprocess_expr(env, ds, e)?; + } + Bool(_lit) => (), + Number(_lit) => (), + String(_lit) => (), + }; + Ok(()) +} diff --git a/aw-query/src/varenv.rs b/aw-query/src/varenv.rs new file mode 100644 index 00000000..a8493394 --- /dev/null +++ b/aw-query/src/varenv.rs @@ -0,0 +1,102 @@ +use std::collections::HashMap; + +use crate::datatype::DataType; +use crate::QueryError; + +struct Var { + pub refs: u32, + pub val: Option, +} + +pub struct VarEnv { + vars: HashMap, +} + +impl VarEnv { + pub fn new() -> Self { + VarEnv { + vars: HashMap::new(), + } + } + + pub fn declare(&mut self, name: String) -> () { + if !self.vars.contains_key(&name) { + let var = Var { refs: 0, val: None }; + println!("declare {}", name); + self.vars.insert(name, var); + } + } + + pub fn declare_static(&mut self, name: String, val: DataType) -> () { + let var = Var { + refs: std::u32::MAX, + val: Some(val), + }; + self.vars.insert(name, var); + } + + // TODO: rename assign? + pub fn insert(&mut self, name: String, val: DataType) -> () { + match self.vars.get_mut(&name) { + Some(var) => var.val = Some(val), + None => panic!(format!("fail, not declared {}", name)), // TODO: Properly handle this + }; + // Return is a little special that it's always taken at the end of the interpretation + if (name == "RETURN") { + self.add_ref("RETURN"); + } + } + + pub fn add_ref(&mut self, name: &str) -> Result<(), QueryError> { + match self.vars.get_mut(name) { + Some(var) => { + if var.refs != std::u32::MAX { + println!("add ref {}, {}", name, var.refs); + var.refs += 1 + } + } + None => return Err(QueryError::VariableNotDefined(name.to_string())), + }; + Ok(()) + } + + pub fn take(&mut self, name: &str) -> Option { + let clone: bool = match self.vars.get_mut(name) { + Some(var) => { + println!("take {}: {}", name, var.refs); + var.refs -= 1; + var.refs > 0 + } + None => return None, + }; + if clone { + match self.vars.get(name) { + Some(var) => { + match var.val { + Some(ref val) => Some(val.clone()), + None => return None, + } + }, + None => return None, + } + } else { + match self.vars.remove(name) { + Some(var) => var.val, + None => return None, + } + } + } + + // TODO: Remove this completely, only needed for TIMEINTERVAL + pub fn deprecated_get(&self, var: &str) -> Option { + match self.vars.get(var) { + Some(var) => { + match var.val { + Some(ref val) => Some(val.clone()), + None => None, + } + }, + None => None, + } + } +} diff --git a/aw-query/tests/query.rs b/aw-query/tests/query.rs index 1ecac21b..ea235dc1 100644 --- a/aw-query/tests/query.rs +++ b/aw-query/tests/query.rs @@ -197,6 +197,19 @@ mod query_tests { aw_query::DataType::Number(n) => assert_eq!(n, 2.0), ref data => panic!("Wrong datatype, {:?}", data), }; + + // test missing return + let code = String::from(""); + match aw_query::query(&code, &interval, &ds) { + Ok(ok) => panic!(format!("Expected QueryError, got {:?}", ok)), + Err(e) => match e { + QueryError::EmptyQuery() => (), + qe => panic!(format!( + "Expected QueryError::EmptyQuery, got {:?}", + qe + )), + }, + }; } #[test]