Skip to content

Commit b933fab

Browse files
committed
[typeabbrev] typeabbrev can be "recursif" (see description)
``` typeabbrev (dl A) (list A). pred p i:dl (dl A). ``` is no more considered as an invalid (looping) type.
1 parent eb195c6 commit b933fab

File tree

4 files changed

+64
-32
lines changed

4 files changed

+64
-32
lines changed

src/compiler.ml

Lines changed: 41 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ type type_abbrev_declaration = {
430430
tavalue : preterm;
431431
taparams : int;
432432
taloc : Loc.t;
433+
timestamp:int
433434
}
434435
[@@ deriving show, ord]
435436

@@ -1291,13 +1292,13 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
12911292
end;
12921293
F.Map.add n (body,loc) m
12931294

1294-
let compile_type_abbrev lcs state { Ast.TypeAbbreviation.name; nparams; loc; value } =
1295+
let compile_type_abbrev geti lcs state { Ast.TypeAbbreviation.name; nparams; loc; value } =
12951296
let state, (taname, _) = Symbols.allocate_global_symbol state name in
12961297
let state, tavalue = preterms_of_ast ~on_type:true loc ~depth:lcs F.Map.empty state (fun ~depth:_ state x -> state, [loc,x]) value in
12971298
let tavalue = assert(List.length tavalue = 1); List.hd tavalue in
12981299
if tavalue.amap.nargs != 0 then
12991300
error ~loc ("type abbreviation for " ^ F.show name ^ " has unbound variables");
1300-
state, { taname; tavalue; taparams = nparams; taloc = loc }
1301+
state, { taname; tavalue; taparams = nparams; taloc = loc; timestamp = geti () }
13011302

13021303
let add_to_index_type_abbrev state m ({ taname; taloc; tavalue; taparams } as x) =
13031304
if C.Map.mem taname m then begin
@@ -1348,8 +1349,9 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
13481349
C.Map.union (fun _ l1 l2 -> Some (Types.merge l1 l2)) t1 t2
13491350

13501351
let merge_type_abbrevs s m1 m2 =
1352+
let len = C.Map.cardinal m1 in
13511353
if C.Map.is_empty m2 then m1 else
1352-
C.Map.fold (fun _ v m -> add_to_index_type_abbrev s m v) m1 m2
1354+
C.Map.fold (fun _ (k:type_abbrev_declaration) m -> add_to_index_type_abbrev s m {k with timestamp=k.timestamp+len}) m2 m1
13531355

13541356
let rec toplevel_clausify loc ~depth state t =
13551357
let state, cl = map_acc (pi2arg loc ~depth []) state (R.split_conj ~depth t) in
@@ -1432,12 +1434,13 @@ let query_preterm_of_ast ~depth macros state (loc, t) =
14321434
C.Map.add k (Types.make v) m
14331435

14341436
let run (state : State.t) ~toplevel_macros p =
1437+
let geti = let i = ref ~-1 in fun () -> incr i; !i in
14351438
(* FIXME: otypes omodes - NO, rewrite spilling on data.term *)
14361439
let rec compile_program omacros lcs state { macros; types; type_abbrevs; modes; body } =
14371440
check_no_overlap_macros omacros macros;
14381441
let active_macros =
14391442
List.fold_left compile_macro omacros macros in
1440-
let state, type_abbrevs = map_acc (compile_type_abbrev lcs) state type_abbrevs in
1443+
let state, type_abbrevs = map_acc (compile_type_abbrev geti lcs) state type_abbrevs in
14411444
let type_abbrevs = List.fold_left (add_to_index_type_abbrev state) C.Map.empty type_abbrevs in
14421445
let state, types =
14431446
map_acc (compile_type lcs) state types in
@@ -1637,11 +1640,11 @@ let subst_amap state f { nargs; c2i; i2n; n2t; n2i } =
16371640
let _apply_subst_list f = apply_subst (fun x -> smart_map (f x))
16381641

16391642
let tabbrevs_map state f m =
1640-
C.Map.fold (fun _ { taname; tavalue; taparams; taloc } m ->
1643+
C.Map.fold (fun _ { taname; tavalue; taparams; taloc; timestamp } m ->
16411644
(* TODO: check for collisions *)
16421645
let taname = f taname in
16431646
let tavalue = smart_map_preterm ~on_type:true state f tavalue in
1644-
C.Map.add taname { taname; tavalue; taparams; taloc } m
1647+
C.Map.add taname { taname; tavalue; taparams; taloc; timestamp } m
16451648
) m C.Map.empty
16461649

16471650
let apply_subst_constant ?live_symbols =
@@ -2739,43 +2742,45 @@ let quote_syntax time new_state { WithMain.clauses; query; compiler_state } =
27392742
close_w_binder argc queryt query.amap]) in
27402743
new_state, clist, q
27412744

2742-
let unfold_type_abbrevs ~compiler_state lcs type_abbrevs { term; loc; amap } =
2743-
let find_opt c =
2744-
try Some (C.Map.find c type_abbrevs) with Not_found -> None in
2745-
let rec aux_tabbrv seen = function
2746-
| Const c as x ->
2745+
let unfold_type_abbrevs ~is_typeabbrev ~compiler_state lcs type_abbrevs { term; loc; amap } ttime =
2746+
let error_undefined ~t1 ~t2 c tavalue =
2747+
if is_typeabbrev && t1 <= t2 then
2748+
error (Format.asprintf "typeabbrev %a uses the undefined %s constant at %a" (R.Pp.ppterm 0 [] ~argsdepth:0 [||]) tavalue.term (Symbols.show compiler_state c) Util.Loc.pp tavalue.loc);
2749+
in
2750+
(* Printf.printf "Istypeabbrev %b\n" is_typeabbrev; *)
2751+
(* C.Map.iter (fun k v -> Format.printf "Looping %d %s %a %d\n%!" k (Symbols.show compiler_state k) pp_term v.tavalue.term v.timestamp) type_abbrevs; *)
2752+
let find_opt c = C.Map.find_opt c type_abbrevs in
2753+
let rec aux_tabbrv ttime = function
2754+
| Const c as x ->
27472755
begin match find_opt c with
2748-
| Some { tavalue; taparams } ->
2756+
| Some { tavalue; taparams; timestamp=time } ->
27492757
if taparams > 0 then
27502758
error ~loc ("type abbreviation " ^ Symbols.show compiler_state c ^ " needs " ^
27512759
string_of_int taparams ^ " arguments");
2752-
if C.Set.mem c seen then
2753-
error ~loc
2754-
("looping while unfolding type abbreviation for "^ Symbols.show compiler_state c);
2755-
aux_tabbrv (C.Set.add c seen) tavalue.term
2760+
error_undefined ttime time c tavalue;
2761+
aux_tabbrv time tavalue.term
27562762
| None -> x
27572763
end
27582764
| App(c,t,ts) as x ->
27592765
begin match find_opt c with
2760-
| Some { tavalue; taparams } ->
2766+
| Some { tavalue; taparams; timestamp=time } ->
27612767
let nargs = 1 + List.length ts in
27622768
if taparams > nargs then
27632769
error ~loc ("type abbreviation " ^ Symbols.show compiler_state c ^ " needs " ^
27642770
string_of_int taparams ^ " arguments, only " ^
27652771
string_of_int nargs ^ " are provided");
2766-
if C.Set.mem c seen then
2767-
error ~loc
2768-
("looping while unfolding type abbreviation for "^ Symbols.show compiler_state c);
2769-
aux_tabbrv (C.Set.add c seen) (R.deref_appuv ~from:lcs ~to_:lcs (t::ts) tavalue.term)
2772+
error_undefined ttime time c tavalue;
2773+
aux_tabbrv time (R.deref_appuv ~from:lcs ~to_:lcs (t::ts) tavalue.term)
27702774
| None ->
2771-
let t1 = aux_tabbrv seen t in
2772-
let ts1 = smart_map (aux_tabbrv seen) ts in
2775+
let t1 = aux_tabbrv ttime t in
2776+
let ts1 = smart_map (aux_tabbrv ttime) ts in
27732777
if t1 == t && ts1 == ts then x
27742778
else App(c,t1,ts1)
27752779
end
2780+
| Lam x -> Lam (aux_tabbrv ttime x)
27762781
| x -> x
27772782
in
2778-
{ term = aux_tabbrv C.Set.empty term; loc; amap; spilling = false }
2783+
{ term = aux_tabbrv ttime term; loc; amap; spilling = false }
27792784

27802785
let term_of_ast ~depth state text =
27812786
if State.get D.while_compiling state then
@@ -2797,25 +2802,29 @@ let static_check ~exec ~checker:(state,program)
27972802
({ WithMain.types; type_abbrevs; initial_depth; compiler_state } as q) =
27982803
let time = `Compiletime in
27992804
let state, p,q = quote_syntax time state q in
2805+
2806+
(* C.Map.iter (fun k ((v:type_abbrev_declaration),t) -> Format.printf "H %s %a %d\n%!" (Symbols.show state k)
2807+
pp_term v.tavalue.term t) type_abbrevs; *)
2808+
2809+
let state, talist =
2810+
C.Map.bindings type_abbrevs |>
2811+
map_acc (fun state (name, { tavalue; timestamp=ttime }) ->
2812+
let tavaluet = unfold_type_abbrevs ~is_typeabbrev:true ~compiler_state initial_depth type_abbrevs tavalue ttime in
2813+
let state, tavaluet = quote_preterm time ~compiler_state state ~on_type:true tavaluet in
2814+
state, App(colonec, D.C.of_string (Symbols.show compiler_state name), [lam2forall tavaluet])) state
2815+
in
28002816
let state, tlist = C.Map.fold (fun tname l (state,tl) ->
28012817
let l = l.Types.lst in
28022818
let state, l =
28032819
List.rev l |> map_acc (fun state { Types.decl = { ttype } } ->
28042820
let state, c = mkQCon time ~compiler_state state ~on_type:false tname in
2805-
let ttypet = unfold_type_abbrevs ~compiler_state initial_depth type_abbrevs ttype in
2821+
let ttypet = unfold_type_abbrevs ~is_typeabbrev:false ~compiler_state initial_depth type_abbrevs ttype 0 in
28062822
let state, ttypet = quote_preterm time ~compiler_state state ~on_type:true ttypet in
28072823
state, App(colonc,c, [close_w_binder forallc ttypet ttype.amap])) state
28082824
in
28092825
state, l :: tl)
28102826
types (state,[]) in
28112827
let tlist = List.concat (List.rev tlist) in
2812-
let state, talist =
2813-
C.Map.bindings type_abbrevs |>
2814-
map_acc (fun state (name, { tavalue; } ) ->
2815-
let tavaluet = unfold_type_abbrevs ~compiler_state initial_depth type_abbrevs tavalue in
2816-
let state, tavaluet = quote_preterm time ~compiler_state state ~on_type:true tavaluet in
2817-
state, App(colonec, D.C.of_string (Symbols.show compiler_state name), [lam2forall tavaluet])) state
2818-
in
28192828
let loc = Loc.initial "(static_check)" in
28202829
let query =
28212830
query_of_term (state, program) (fun ~depth state ->

tests/sources/typeabbrv13.elpi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
typeabbrev (dl A) (list A).
2+
3+
pred p i:dl (dl A).
4+
5+
main.
6+

tests/sources/typeabbrv14.elpi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
typeabbrev (dl A) (dl A).
2+
3+
pred p i:dl (dl A).
4+
5+
main.
6+

tests/suite/correctness_FO.ml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,17 @@ let () = declare "typeabbrv12"
9999
~description:"type abbreviations and error messages"
100100
()
101101

102+
let () = declare "typeabbrv13"
103+
~source_elpi:"typeabbrv13.elpi"
104+
~description:"type abbreviations"
105+
()
106+
107+
let () = declare "typeabbrv14"
108+
~source_elpi:"typeabbrv14.elpi"
109+
~description:"type abbreviations"
110+
~expectation:(FailureOutput (Str.regexp "SYMBOL.*uses the undefined dl constant"))
111+
()
112+
102113
let () = declare "conj2"
103114
~source_elpi:"conj2.elpi"
104115
~description:"parsing and evaluation of & (binary conj)"

0 commit comments

Comments
 (0)