Skip to content

Commit c19dd60

Browse files
committed
Merge branch 'compile-clause-early' of https://github.com/LPCIC/elpi into compile-clause-early
2 parents ac35f29 + b5db4f5 commit c19dd60

File tree

5 files changed

+34
-3
lines changed

5 files changed

+34
-3
lines changed

src/compiler.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ end = struct (* {{{ *)
718718
| Some "Map" -> Some Map
719719
| Some "Hash" -> Some HashMap
720720
| Some "DTree" -> Some DiscriminationTree
721-
| Some s -> error ~loc ("unknown indexing directive " ^ s) in
721+
| Some s -> error ~loc ("unknown indexing directive " ^ s ^ ". Valid ones are: Map, Hash, DTree.") in
722722
begin match r with
723723
| None -> aux_tatt (Some (Structured.Index(i,it))) rest
724724
| Some (Structured.Index _) -> duplicate_err "index"

src/discrimination_tree.ml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,14 @@ module Trie = struct
165165

166166
let is_empty x = x == empty
167167

168+
let rec replace p x = function
169+
| Node { data; other; listTailVariable; map } ->
170+
Node {
171+
data = data |> List.map (fun y -> if p y then x else y);
172+
other = other |> Option.map (replace p x);
173+
listTailVariable = listTailVariable |> Option.map (replace p x);
174+
map = map |> Ptmap.map (replace p x);
175+
}
168176

169177
let add (a : Path.t) v t =
170178
let max = ref 0 in
@@ -383,6 +391,7 @@ let retrieve cmp_data path { t } =
383391
let r = call (retrieve ~pos:0 path t) in
384392
Bl.of_list @@ List.sort cmp_data r
385393

394+
let replace p x i = { i with t = Trie.replace p x i.t }
386395

387396
module Internal = struct
388397
let kConstant = kConstant

src/discrimination_tree.mli

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ val empty_dt : 'b list -> 'a t
5151
*)
5252
val retrieve : ('a -> 'a -> int) -> Path.t -> 'a t -> 'a Bl.scan
5353

54+
val replace : ('a -> bool) -> 'a -> 'a t -> 'a t
55+
5456
(***********************************************************)
5557
(* Printers *)
5658
(***********************************************************)

src/runtime.ml

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2710,7 +2710,25 @@ let add1clause ~depth { idx; time; times } ~time_dir ~insert ~empty ~cons ~copy
27102710
StrMap.add id clause.timestamp times in
27112711
let idx =
27122712
try
2713-
add1clause2 ~depth idx ~insert ~empty ~copy graft grafting_reference predicate clause (Ptmap.find predicate idx);
2713+
(* TODO: do this only at compile time *)
2714+
match graft with
2715+
| Some (Elpi_parser.Ast.Structured.Replace _) ->
2716+
Ptmap.map (function
2717+
| TwoLevelIndex {
2718+
argno; mode;
2719+
all_clauses;
2720+
flex_arg_clauses;
2721+
arg_idx;
2722+
} -> TwoLevelIndex {
2723+
argno; mode;
2724+
all_clauses = insert graft grafting_reference clause all_clauses;
2725+
flex_arg_clauses = insert graft grafting_reference clause flex_arg_clauses;
2726+
arg_idx = Ptmap.map (fun l -> insert graft grafting_reference clause l) arg_idx;
2727+
}
2728+
| BitHash { mode; args; args_idx } -> BitHash { mode; args; args_idx = Ptmap.map (fun l -> insert graft grafting_reference clause l) args_idx }
2729+
| IndexWithDiscriminationTree {mode; arg_depths; args_idx; } -> IndexWithDiscriminationTree {mode; arg_depths; args_idx = Discrimination_tree.replace (fun x -> x.timestamp = grafting_reference) clause args_idx; }
2730+
) idx
2731+
| _ -> add1clause2 ~depth idx ~insert ~empty ~copy graft grafting_reference predicate clause (Ptmap.find predicate idx);
27142732
with
27152733
| Not_found ->
27162734
match classify_clause_argno ~depth 0 [] clause.args with
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
:index (1) % "Hash" "DTree"
12
pred p o:int.
23
:name "replace_me"
34
p 1.
@@ -6,4 +7,5 @@ p 1.
67
p 2.
78

89
main :-
9-
std.findall (p _) [p 2].
10+
std.findall (p _) [p 2],
11+
std.findall (p 1) [].

0 commit comments

Comments
 (0)