From f17f2b7dec74a372facbce6cd946571553277538 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Basile=20Cl=C3=A9ment?= Date: Wed, 29 May 2024 11:52:55 +0200 Subject: [PATCH] feat(BV): Only store domains on variable parts This patch changes the way that bit-vector domains are stored in order to share domains between multiple bit-vectors with the same variable part, including across different bit-width. Currently, if we had terms [x], [00 @ x], and [10 @ x], we would store domains for each of these. With this patch, we only store a domain for [x] and rebuild the domains for [00 @ x] and [10 @ x] dynamically; we do this by computing normal forms composed of a variable part ([x] here) and a constant offset (either [0] or [10]). This reduces the number of variables handled by the system (which is important for global domains such as global difference) and reduces the number of trivial propagations. Note that this patch is not reverting #1044. In particular, we are considering [x @ y] as a unique (composite) variable built up from the atomic variables [x] and [y] (in the same way that the polynomial [x + y] is built up from the monomial [x] and [y] in the IntervalCalculus module; in fact, the implementation is intentionally abstracted in order to be useable with polynomials). This is necessary because, as noted in #1044, the interval domain for [x @ y] cannot be built reconstructed precisely from the interval domains for [x] and [y]. The patch includes a bit of refactoring of the way we handle constraints; notably, instead of aggressively substituting constraints, we now store the original constraint and call [Uf.find] at propagation time. This avoids having to define a substitution operation on normal forms and allows storing the constraint dependencies directly inside the domain (in the previous design, we could not do it easily without repeating the expensive substitution of constraint arguments). --- src/lib/reasoners/adt_rel.ml | 6 +- src/lib/reasoners/bitv_rel.ml | 1368 ++++++++++++++++++------------ src/lib/reasoners/rel_utils.ml | 1443 ++++++++++++++++++++------------ src/lib/reasoners/uf.ml | 4 +- src/lib/reasoners/uf.mli | 4 +- 5 files changed, 1755 insertions(+), 1070 deletions(-) diff --git a/src/lib/reasoners/adt_rel.ml b/src/lib/reasoners/adt_rel.ml index 18eb871b3f..0009c32989 100644 --- a/src/lib/reasoners/adt_rel.ml +++ b/src/lib/reasoners/adt_rel.ml @@ -191,7 +191,7 @@ module Domains = struct with Not_found -> Domain.unknown (X.type_info r) - let add r t = + let init r t = match Th.embed r with | Alien _ when not (MX.mem r t.domains) -> (* We have to add a default domain if the key `r` is not in map in order @@ -236,7 +236,7 @@ module Domains = struct let t = remove r t in tighten nr nd t - | exception Not_found -> add nr t + | exception Not_found -> init nr t (* [propagate f a t] iterates on all the changed domains of [t] since the last call of [propagate]. The list of changed domains is flushed after @@ -431,7 +431,7 @@ let add r uf domains = | Ty.Tadt _ -> Debug.add r; let rr, _ = Uf.find_r uf r in - Domains.add rr domains + Domains.init rr domains | _ -> domains diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index a6b8bc01b5..46cf696396 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -89,6 +89,12 @@ module Interval_domain = struct let add_explanation = Intervals.Int.add_explanation + type constant = Z.t + + let constant n = Intervals.Int.of_bounds (Closed n) (Closed n) + + let filter_ty = is_bv_ty + let unknown = function | Ty.Tbitv n -> Intervals.Int.of_bounds @@ -97,8 +103,6 @@ module Interval_domain = struct Fmt.invalid_arg "unknown: only bit-vector types are supported; got %a" Ty.print ty - let filter_ty = is_bv_ty - let intersect x y = match Intervals.Int.intersect x y with | Empty ex -> @@ -108,128 +112,131 @@ module Interval_domain = struct let lognot sz int = Intervals.Int.extract ~ofs:0 ~len:sz @@ Intervals.Int.lognot int - let fold_signed f { Bitv.value; negated } sz int acc = - f value (if negated then lognot sz int else int) acc - - let point ?ex n = - Intervals.Int.of_bounds ?ex (Closed n) (Closed n) - - let fold_leaves f r int acc = - let width = bitwidth r in - let j, acc = - List.fold_left (fun (j, acc) { Bitv.bv; sz } -> - (* sz = j - i + 1 => i = j - sz + 1 *) - let int = Intervals.Int.extract int ~ofs:(j - sz + 1) ~len:sz in - - let acc = match bv with - | Bitv.Cte z -> - (* Nothing to update, but still check for consistency *) - ignore @@ intersect int (point z); - acc - | Other s -> fold_signed f s sz int acc - | Ext (r, r_size, i, j) -> - (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + sz = r_size); - let lo = unknown (Tbitv i) in - let int = Intervals.Int.scale Z.(~$1 lsl i) int in - let hi = unknown (Tbitv (r_size - j - 1)) in - let hi = - Intervals.Int.scale Z.(~$1 lsl Stdlib.(i + sz)) hi - in - fold_signed f r r_size Intervals.Int.(add hi (add int lo)) acc - in - (j - sz), acc - ) (width - 1, acc) (Shostak.Bitv.embed r) - in - assert (j = -1); - acc + let add_offset d cte = + Intervals.Int.add d (Intervals.Int.of_bounds (Closed cte) (Closed cte)) - let map_signed f { Bitv.value; negated } sz = - if negated then lognot sz (f value) else f value - - let map_leaves f r = - List.fold_left (fun ival { Bitv.bv; sz } -> - let ival = Intervals.Int.scale Z.(~$1 lsl sz) ival in - Intervals.Int.add ival @@ - match bv with - | Bitv.Cte z -> point z - | Other s -> map_signed f s sz - | Ext (s, sz', i, j) -> - Intervals.Int.extract (map_signed f s sz') ~ofs:i ~len:(j - i + 1) - ) (point Z.zero) (Shostak.Bitv.embed r) + let sub_offset d cte = + Intervals.Int.sub d (Intervals.Int.of_bounds (Closed cte) (Closed cte)) end -module Interval_domains = Rel_utils.Domains_make(Interval_domain) +type 'a explained = { value : 'a ; explanation : Explanation.t } -module Bitlist_domain : Rel_utils.Domain with type t = Bitlist.t = struct - (* Note: these functions are not in [Bitlist] proper in order to avoid a - (direct) dependency from [Bitlist] to the [Shostak] module. *) +let explained ~ex value = { value ; explanation = ex } - include Bitlist +module ExplainedOrdered(V : Rel_utils.OrderedType) : + Rel_utils.OrderedType with type t = V.t explained = +struct + type t = V.t explained - let filter_ty = is_bv_ty - - let fold_signed sz f { Bitv.value; negated } bl acc = - let bl = if negated then extract (lognot bl) 0 sz else bl in - f value bl acc + let pp ppf { value; _ } = V.pp ppf value - let fold_leaves f r bl acc = - let sz = bitwidth r in - let (acc, _, _) = List.fold_left (fun (acc, bl, w) { Bitv.bv; sz } -> - (* Extract the bitlist associated with the current component *) - let mid = w - sz in - let bl_tail = extract bl 0 mid in - let bl = extract bl mid (w - mid) in + let compare { value = v1; _ } { value = v2; _ } = V.compare v1 v2 - match bv with - | Bitv.Cte z -> - assert (Z.numbits z <= sz); - (* Nothing to update, but still check for consistency! *) - ignore @@ intersect bl (exact z Ex.empty); - acc, bl_tail, mid - | Other r -> fold_signed sz f r bl acc, bl_tail, mid - | Ext (r, r_size, i, j) -> - (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + w - mid = r_size); - let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in - let lo = Bitlist.(extract unknown 0 i) in - let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in - fold_signed r_size f r bl_hd acc, - bl_tail, - mid - ) (acc, bl, sz) (Shostak.Bitv.embed r) - in acc + module Set = Set.Make(struct + type nonrec t = t - let map_signed sz f { Bitv.value; negated } = - let bl = f value in - if negated then extract (lognot bl) 0 sz else bl + let compare = compare + end) - let map_leaves f r = - List.fold_left (fun bl { Bitv.bv; sz } -> - bl lsl sz lor - match bv with - | Bitv.Cte z -> extract (exact z Ex.empty) 0 sz - | Other r -> map_signed sz f r - | Ext (r, r_sz, i, j) -> - extract (map_signed r_sz f r) i (j - i + 1) - ) (exact Z.zero Ex.empty) (Shostak.Bitv.embed r) + module Map = Map.Make(struct + type nonrec t = t - let unknown = function - | Ty.Tbitv n -> extract unknown 0 n - | _ -> - (* Only bit-vector values can have bitlist domains. *) - invalid_arg "unknown" + let compare = compare + end) end -module Bitlist_domains = Rel_utils.Domains_make(Bitlist_domain) +module BitvNormalForm = struct + (** Normal form for bit-vector values. + + We decompose non-constant bit-vector compositions as a variable part, + where all constant bits are set to [0] and all high constant bits are + chopped off, and an offset with all the constant bits. We consider the + variable part atomic if it is a single non-negated variable. + + Assuming [x] and [y] are bit-vectors of width 2: + - [101 @ x] is [x + 10100] ; + - [10 @ x @ 01] is [(x @ 00) + 100001] ; + - [10 @ y<0, 0> @ y<1, 1>] is [(y<0, 0> @ y<1>1) + 1000] ; + - [10 @ x @ 11 @ y] is [(x @ 00 @ y) + 10001100] *) + + type constant = Z.t + + type atom = X.r + + type composite = X.r + + type t = + | Constant of constant + | Atom of atom * constant + | Composite of composite * constant + + type expr = X.r + + let normal_form r = + let rec loop cte rev_acc = function + | [] -> ( + match rev_acc with + | [] -> + Constant cte + | [ { Bitv.bv = Bitv.Other { value ; negated = false }; _ } ] -> + Atom (value, cte) + | _ -> + Composite (Shostak.Bitv.is_mine (List.rev rev_acc), cte) + ) + | { Bitv.bv = Bitv.Cte n ; sz } :: bv' -> + let cte = Z.(cte lsl sz lor n) in + let acc = + match rev_acc with + | [] -> [] + | _ -> { Bitv.bv = Bitv.Cte Z.zero ; sz } :: rev_acc + in + loop cte acc bv' + | x :: bv' -> + let cte = Z.(cte lsl x.sz) in + loop cte (x :: rev_acc) bv' + in loop Z.zero [] (Shostak.Bitv.embed r) +end module Constraint : sig - include Rel_utils.Constraint + type binop = + (* Bitwise operations *) + | Band | Bor | Bxor + (* Arithmetic operations *) + | Badd | Bmul | Budiv | Burem + (* Shift operations *) + | Bshl | Blshr + + type fun_t = + | Fbinop of binop * X.r * X.r + + type binrel = Rule | Rugt + + type rel_t = + | Rbinrel of binrel * X.r * X.r + + type view = + | Cfun of X.r * fun_t + | Crel of rel_t + + type t + + val view : t -> view + + val pp : t Fmt.t + (** Pretty-printer for constraints. *) val equal : t -> t -> bool val hash : t -> int + val compare : t -> t -> int + (** Comparison function for constraints. The comparison function is + arbitrary and has no semantic meaning. You should not depend on any of + its properties, other than it defines an (arbitrary) total order on + constraint representations. *) + + val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a + val bvand : X.r -> X.r -> X.r -> t (** [bvand x y z] is the constraint [x = y & z] *) @@ -268,14 +275,6 @@ module Constraint : sig val bvule : X.r -> X.r -> t val bvugt : X.r -> X.r -> t - - val propagate_bitlist : ex:Ex.t -> t -> Bitlist_domains.Ephemeral.t -> unit - (** [propagate ~ex t dom] propagates the constraint [t] in domain [dom]. - - The explanation [ex] justifies that the constraint [t] applies, and must - be added to any domain that gets updated during propagation. *) - - val propagate_interval : ex:Ex.t -> t -> Interval_domains.Ephemeral.t -> unit end = struct type binop = (* Bitwise operations *) @@ -330,9 +329,327 @@ end = struct | Band | Bor | Bxor | Badd | Bmul -> true | Budiv | Burem | Bshl | Blshr -> false + type fun_t = + | Fbinop of binop * X.r * X.r + + let pp_fun_t ppf = function + | Fbinop (op, x, y) -> + Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binop op X.print x X.print y + + let equal_fun_t f1 f2 = + match f1, f2 with + | Fbinop (op1, x1, y1), Fbinop (op2, x2, y2) -> + equal_binop op1 op2 && X.equal x1 x2 && X.equal y1 y2 + + let hash_fun_t = function + | Fbinop (op, x, y) -> Hashtbl.hash (hash_binop op, X.hash x, X.hash y) + + let normalize_fun_t = function + | Fbinop (op, x, y) when is_commutative op && X.hash_cmp x y > 0 -> + Fbinop (op, y, x) + | Fbinop _ as e -> e + + type binrel = Rule | Rugt + + let pp_binrel ppf = function + | Rule -> Fmt.pf ppf "bvule" + | Rugt -> Fmt.pf ppf "bvugt" + + let equal_binrel : binrel -> binrel -> bool = Stdlib.(=) + + let hash_binrel : binrel -> int = Hashtbl.hash + + type rel_t = + | Rbinrel of binrel * X.r * X.r + + let pp_rel_t ppf = function + | Rbinrel (op, x, y) -> + Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binrel op X.print x X.print y + + let equal_rel_t f1 f2 = + match f1, f2 with + | Rbinrel (op1, x1, y1), Rbinrel (op2, x2, y2) -> + equal_binrel op1 op2 && X.equal x1 x2 && X.equal y1 y2 + + let hash_rel_t = function + | Rbinrel (op, x, y) -> Hashtbl.hash (hash_binrel op, X.hash x, X.hash y) + + let normalize_rel_t = function + (* No normalization for relations yet *) + | r -> r + + type view = + | Cfun of X.r * fun_t + | Crel of rel_t + + let pp_view ppf = function + | Cfun (r, fn) -> + Fmt.(pf ppf "%a =@ %a" (box X.print) r (box pp_fun_t) fn) + | Crel rel -> + pp_rel_t ppf rel + + let equal_view c1 c2 = + match c1, c2 with + | Cfun (r1, f1), Cfun (r2, f2) -> + X.equal r1 r2 && equal_fun_t f1 f2 + | Cfun _, _ | _, Cfun _ -> false + + | Crel r1, Crel r2 -> + equal_rel_t r1 r2 + + let hash_view = function + | Cfun (r, f) -> Hashtbl.hash (0, X.hash r, hash_fun_t f) + | Crel r -> Hashtbl.hash (1, hash_rel_t r) + + let normalize_view = function + | Cfun (r, f) -> Cfun (r, normalize_fun_t f) + | Crel r -> Crel (normalize_rel_t r) + + type t = { view : view ; mutable tag : int } + + let view { view ; _ } = view + + let pp ppf { view; _ } = pp_view ppf view + + module W = Weak.Make(struct + type nonrec t = t + + let equal c1 c2 = equal_view c1.view c2.view + + let hash c = hash_view c.view + end) + + let hcons = + let cnt = ref 0 in + let tbl = W.create 17 in + fun view -> + let view = normalize_view view in + let tagged = W.merge tbl { view ; tag = -1 } in + if tagged.tag = -1 then ( + tagged.tag <- !cnt; + incr cnt + ); + tagged + + let cfun r f = hcons @@ Cfun (r, f) + + let cbinop op r x y = cfun r (Fbinop (op, x, y)) + + let bvand = cbinop Band + let bvor = cbinop Bor + let bvxor = cbinop Bxor + let bvadd = cbinop Badd + let bvsub r x y = + (* r = x - y <-> x = r + y *) + bvadd x r y + let bvmul = cbinop Bmul + let bvudiv = cbinop Budiv + let bvurem = cbinop Burem + let bvshl = cbinop Bshl + let bvlshr = cbinop Blshr + + let crel r = hcons @@ Crel r + + let cbinrel op x y = crel (Rbinrel (op, x, y)) + + let bvule = cbinrel Rule + let bvugt = cbinrel Rugt + + let equal c1 c2 = c1.tag = c2.tag + + let hash c = Hashtbl.hash c.tag + + let compare c1 c2 = Int.compare c1.tag c2.tag + + let fold_args_fun_t f fn acc = + match fn with + | Fbinop (_, x, y) -> f y (f x acc) + + let fold_args_rel_t f r acc = + match r with + | Rbinrel (_op, x, y) -> f y (f x acc) + + let fold_args_view f c acc = + match c with + | Cfun (r, fn) -> fold_args_fun_t f fn (f r acc) + | Crel r -> fold_args_rel_t f r acc + + let fold_args f c acc = fold_args_view f (view c) acc +end + + +module EC = ExplainedOrdered(struct + include Constraint + + module Set = Set.Make(Constraint) + module Map = Map.Make(Constraint) + end) + +module CompositeIntervalDomain = struct + type var = X.r + + type atom = X.r + + type domain = Interval_domain.t + + let map_signed f { Bitv.value; negated } sz = + if negated then Interval_domain.lognot sz (f value) else f value + + let map_domain f r = + List.fold_left (fun ival { Bitv.bv; sz } -> + let ival = Intervals.Int.scale Z.(~$1 lsl sz) ival in + Intervals.Int.add ival @@ + match bv with + | Bitv.Cte z -> Interval_domain.constant z + | Other s -> map_signed f s sz + | Ext (s, sz', i, j) -> + Intervals.Int.extract (map_signed f s sz') ~ofs:i ~len:(j - i + 1) + ) (Interval_domain.constant Z.zero) (Shostak.Bitv.embed r) +end + +module XComposite = struct + include Rel_utils.XComparable + + type atom = X.r + + let fold f r acc = + List.fold_left (fun acc { Bitv.bv ; _ } -> + match bv with + | Bitv.Cte _ -> acc + | Other { value ; _ } -> f value acc + | Ext ({ value ; _ }, _, _, _) -> f value acc + ) acc (Shostak.Bitv.embed r) +end + +module XAtom = struct + include Rel_utils.XComparable + + let type_info = X.type_info +end + +module Interval_domains = + Rel_utils.Domains_make + (Interval_domain) + (XAtom) + (XComposite) + (CompositeIntervalDomain) + (BitvNormalForm) + (EC) + +module Interval_domains_uf = + Rel_utils.UfHandle + (Interval_domain) + (Interval_domains.Ephemeral) + +module Bitlist_domain = struct + (* Note: these functions are not in [Bitlist] proper in order to avoid a + (direct) dependency from [Bitlist] to the [Shostak] module. *) + + include Bitlist + + type constant = Z.t + + let constant n = exact n Ex.empty + + let filter_ty = is_bv_ty + + let unknown = function + | Ty.Tbitv n -> extract unknown 0 n + | _ -> + (* Only bit-vector values can have bitlist domains. *) + invalid_arg "unknown" + + let add_offset d cte = + Bitlist.logor d (Bitlist.exact cte Explanation.empty) + + let sub_offset d cte = + let cte = Bitlist.exact cte Explanation.empty in + Bitlist.logand d (Bitlist.lognot cte) +end + +module CompositeBitlistDomain = struct + type var = X.r + + type atom = X.r + + type domain = Bitlist_domain.t + + let map_signed sz f { Bitv.value; negated } = + let bl = f value in + if negated then Bitlist.extract (Bitlist.lognot bl) 0 sz else bl + + let map_domain f r = + List.fold_left (fun bl { Bitv.bv; sz } -> + let open Bitlist in + bl lsl sz lor + match bv with + | Bitv.Cte z -> extract (Bitlist_domain.constant z) 0 sz + | Other r -> map_signed sz f r + | Ext (r, r_sz, i, j) -> + extract (map_signed r_sz f r) i (j - i + 1) + ) (Bitlist_domain.constant Z.zero) (Shostak.Bitv.embed r) +end + +module Bitlist_domains = + Rel_utils.Domains_make + (Bitlist_domain) + (XAtom) + (XComposite) + (CompositeBitlistDomain) + (BitvNormalForm) + (EC) + +module Bitlist_domains_uf = + Rel_utils.UfHandle + (Bitlist_domain) + (Bitlist_domains.Ephemeral) + +(** The ['c acts] type is used to register new facts and constraints in + [Propagator.simplify]. *) +type 'c acts = + { acts_add_lit_view : ex:Explanation.t -> X.r L.view -> unit + (** Assert a semantic literal. *) + ; acts_add_eq : ex:Explanation.t -> X.r -> X.r -> unit + (** Assert equality between two semantic values. *) + ; acts_add_constraint : ex:Explanation.t -> 'c -> unit + (** Assert a new constraint. *) + } + +module Propagator : sig + type t = Constraint.t + (** The type of constraints. + + Constraints apply to semantic values of type [X.r] as arguments. *) + + val simplify : Uf.t -> t -> t acts -> bool + (** [simplify c acts] simplifies the constraint [c] by calling appropriate + functions on [acts]. + + {b Note}: All the facts and constraints added through [acts] must be + logically implied by [c] {b only}. Doing otherwise is a {b soundness bug}. + + Returns [true] if the constraint has been fully simplified and can + be removed, and [false] otherwise. + + {b Note}: Returning [true] will cause the constraint to be removed, even + if it was re-added with [acts_add_constraint]. If you want to add new + facts/constraints but keep the existing constraint (usually a bad idea), + return [false] instead. *) + + val propagate_bitlist : Bitlist_domains_uf.t -> ex:Ex.t -> t -> unit + (** [propagate dom ~ex t] propagates the constraint [t] in domain [dom]. + + The explanation [ex] justifies that the constraint [t] applies, and must + be added to any domain that gets updated during propagation. *) + + val propagate_interval : + Interval_domains_uf.t -> ex:Ex.t -> t -> unit +end = struct + include Constraint + let propagate_binop ~ex sz dx op dy dz = - let open Bitlist_domains.Ephemeral in let norm bl = Bitlist.extract bl 0 sz in + let open Rel_utils.HandleNotations(Bitlist_domain)(Bitlist_domains_uf) in match op with | Band -> update ~ex dx @@ norm @@ Bitlist.logand !!dy !!dz; @@ -383,7 +700,7 @@ end = struct () let propagate_interval_binop ~ex sz dr op dx dy = - let open Interval_domains.Ephemeral in + let open Rel_utils.HandleNotations(Interval_domain)(Interval_domains_uf) in let norm i = Intervals.Int.extract i ~ofs:0 ~len:sz in match op with | Badd -> @@ -410,58 +727,20 @@ end = struct (* No interval propagation for bitwise operators yet *) () - type fun_t = - | Fbinop of binop * X.r * X.r - - let pp_fun_t ppf = function - | Fbinop (op, x, y) -> - Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binop op X.print x X.print y - - let equal_fun_t f1 f2 = - match f1, f2 with - | Fbinop (op1, x1, y1), Fbinop (op2, x2, y2) -> - equal_binop op1 op2 && X.equal x1 x2 && X.equal y1 y2 - - let hash_fun_t = function - | Fbinop (op, x, y) -> Hashtbl.hash (hash_binop op, X.hash x, X.hash y) - - let normalize_fun_t = function - | Fbinop (op, x, y) when is_commutative op && X.hash_cmp x y > 0 -> - Fbinop (op, y, x) - | Fbinop _ as e -> e - - let fold_args_fun_t f fn acc = - match fn with - | Fbinop (_, x, y) -> f y (f x acc) - - let subst_fun_t rr nrr = function - | Fbinop (op, x, y) -> Fbinop (op, X.subst rr nrr x, X.subst rr nrr y) - let propagate_fun_t ~ex dom r f = - let open Bitlist_domains.Ephemeral in - let get r = handle dom r in + let get r = Bitlist_domains_uf.entry dom r in match f with | Fbinop (op, x, y) -> let n = bitwidth r in propagate_binop ~ex n (get r) op (get x) (get y) let propagate_interval_fun_t ~ex dom r f = - let get r = Interval_domains.Ephemeral.handle dom r in + let get r = Interval_domains_uf.entry dom r in match f with | Fbinop (op, x, y) -> let sz = bitwidth r in propagate_interval_binop ~ex sz (get r) op (get x) (get y) - type binrel = Rule | Rugt - - let pp_binrel ppf = function - | Rule -> Fmt.pf ppf "bvule" - | Rugt -> Fmt.pf ppf "bvugt" - - let equal_binrel : binrel -> binrel -> bool = Stdlib.(=) - - let hash_binrel : binrel -> int = Hashtbl.hash - let propagate_binrel ~ex:_ _op _dx _dy = (* No bitlist propagation for relations yet *) () @@ -477,7 +756,7 @@ end = struct Intervals.Int.of_bounds ~ex:(Ex.union ex ex') inf Unbounded let propagate_less_than ~ex ~strict dx dy = - let open Interval_domains.Ephemeral in + let open Rel_utils.HandleNotations(Interval_domain)(Interval_domains_uf) in (* Do not use [update] to make sure that the justification is only stored on the upper/lower bound. *) update ~ex:Ex.empty dx (less_than_sup ~ex ~strict !!dy); @@ -490,153 +769,31 @@ end = struct | Rugt -> propagate_less_than ~ex ~strict:true dy dx - type rel_t = - | Rbinrel of binrel * X.r * X.r - - let pp_rel_t ppf = function - | Rbinrel (op, x, y) -> - Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binrel op X.print x X.print y - - let equal_rel_t f1 f2 = - match f1, f2 with - | Rbinrel (op1, x1, y1), Rbinrel (op2, x2, y2) -> - equal_binrel op1 op2 && X.equal x1 x2 && X.equal y1 y2 - - let hash_rel_t = function - | Rbinrel (op, x, y) -> Hashtbl.hash (hash_binrel op, X.hash x, X.hash y) - - let normalize_rel_t = function - (* No normalization for relations yet *) - | r -> r - - let fold_args_rel_t f r acc = - match r with - | Rbinrel (_op, x, y) -> f y (f x acc) - - let subst_rel_t rr nrr = function - | Rbinrel (op, x, y) -> Rbinrel (op, X.subst rr nrr x, X.subst rr nrr y) - let propagate_rel_t ~ex dom r = - let open Bitlist_domains.Ephemeral in - let get r = handle dom r in + let get r = Bitlist_domains_uf.entry dom r in match r with | Rbinrel (op, x, y) -> propagate_binrel ~ex op (get x) (get y) let propagate_interval_rel_t ~ex dom r = - let get r = Interval_domains.Ephemeral.handle dom r in + let get r = Interval_domains_uf.entry dom r in match r with | Rbinrel (op, x, y) -> propagate_interval_binrel ~ex op (get x) (get y) - type repr = - | Cfun of X.r * fun_t - | Crel of rel_t - - let pp_repr ppf = function - | Cfun (r, fn) -> - Fmt.(pf ppf "%a =@ %a" (box X.print) r (box pp_fun_t) fn) - | Crel rel -> - pp_rel_t ppf rel - - let equal_repr c1 c2 = - match c1, c2 with - | Cfun (r1, f1), Cfun (r2, f2) -> - X.equal r1 r2 && equal_fun_t f1 f2 - | Cfun _, _ | _, Cfun _ -> false - - | Crel r1, Crel r2 -> - equal_rel_t r1 r2 - - let hash_repr = function - | Cfun (r, f) -> Hashtbl.hash (0, X.hash r, hash_fun_t f) - | Crel r -> Hashtbl.hash (1, hash_rel_t r) - - let normalize_repr = function - | Cfun (r, f) -> Cfun (r, normalize_fun_t f) - | Crel r -> Crel (normalize_rel_t r) - - let fold_args_repr f c acc = - match c with - | Cfun (r, fn) -> fold_args_fun_t f fn (f r acc) - | Crel r -> fold_args_rel_t f r acc - - let subst_repr rr nrr = function - | Cfun (r, f) -> Cfun (X.subst rr nrr r, subst_fun_t rr nrr f) - | Crel r -> Crel (subst_rel_t rr nrr r) - - let propagate_repr ~ex dom = function + let propagate_view ~ex dom = function | Cfun (r, f) -> propagate_fun_t ~ex dom r f | Crel r -> propagate_rel_t ~ex dom r - let propagate_interval_repr ~ex dom = function + let propagate_interval_view ~ex dom = function | Cfun (r, f) -> propagate_interval_fun_t ~ex dom r f | Crel r -> propagate_interval_rel_t ~ex dom r - type t = { repr : repr ; mutable tag : int } - - let pp ppf { repr; _ } = pp_repr ppf repr - - module W = Weak.Make(struct - type nonrec t = t - - let equal c1 c2 = equal_repr c1.repr c2.repr - - let hash c = hash_repr c.repr - end) - - let hcons = - let cnt = ref 0 in - let tbl = W.create 17 in - fun repr -> - let repr = normalize_repr repr in - let tagged = W.merge tbl { repr ; tag = -1 } in - if tagged.tag = -1 then ( - tagged.tag <- !cnt; - incr cnt - ); - tagged - - let cfun r f = hcons @@ Cfun (r, f) - - let cbinop op r x y = cfun r (Fbinop (op, x, y)) - - let bvand = cbinop Band - let bvor = cbinop Bor - let bvxor = cbinop Bxor - let bvadd = cbinop Badd - let bvsub r x y = - (* r = x - y <-> x = r + y *) - bvadd x r y - let bvmul = cbinop Bmul - let bvudiv = cbinop Budiv - let bvurem = cbinop Burem - let bvshl = cbinop Bshl - let bvlshr = cbinop Blshr - - let crel r = hcons @@ Crel r - - let cbinrel op x y = crel (Rbinrel (op, x, y)) - - let bvule = cbinrel Rule - let bvugt = cbinrel Rugt - - let equal c1 c2 = c1.tag = c2.tag - - let hash c = Hashtbl.hash c.tag - - let compare c1 c2 = Int.compare c1.tag c2.tag - - let fold_args f c acc = fold_args_repr f c.repr acc - - let subst rr nrr c = - hcons @@ subst_repr rr nrr c.repr + let propagate_bitlist dom ~ex c = + propagate_view ~ex dom (view c) - let propagate_bitlist ~ex c dom = - propagate_repr ~ex dom c.repr - - let propagate_interval ~ex c dom = - propagate_interval_repr ~ex dom c.repr + let propagate_interval dom ~ex c = + propagate_interval_view ~ex dom (view c) let const sz n = Shostak.Bitv.is_mine [ { bv = Cte (Z.extract n 0 sz); sz } ] @@ -652,59 +809,59 @@ end = struct | _ -> invalid_arg "const_value" (* Add the constraint: r = x *) - let add_eq acts r x = - acts.Rel_utils.acts_add_eq r x + let add_eq ~ex acts r x = + acts.acts_add_eq ~ex r x (* Add the constraint: r = c *) - let add_eq_const acts r c = - add_eq acts r @@ const (bitwidth r) c + let add_eq_const ~ex acts r c = + add_eq ~ex acts r @@ const (bitwidth r) c (* Add the constraint: r = x & c *) - let add_and_const acts r x c = + let add_and_const ~ex acts r x c = (* TODO: apply to extractions for any [c]? Could be expensive for Shostak *) if Z.equal c Z.zero then ( - add_eq_const acts r Z.zero; + add_eq_const ~ex acts r Z.zero; true ) else if Z.equal c (Z.extract Z.minus_one 0 (bitwidth r)) then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else false (* Add the constraint: r = x | c *) - let add_or_const acts r x c = + let add_or_const ~ex acts r x c = (* TODO: apply to extractions for any [c]? Could be expensive for Shostak *) if Z.equal c Z.zero then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else if Z.equal c (Z.extract Z.minus_one 0 (bitwidth r)) then ( - add_eq_const acts r Z.minus_one; + add_eq_const ~ex acts r Z.minus_one; true ) else false (* Add the constraint: r = x ^ c *) - let add_xor_const acts r x c = + let add_xor_const ~ex acts r x c = (* TODO: apply to extractions for any [c]? Could be expensive for Shostak *) if Z.equal c Z.zero then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else if Z.equal c (Z.extract Z.minus_one 0 (bitwidth r)) then ( - add_eq acts r + add_eq ~ex acts r (Shostak.Bitv.is_mine @@ Bitv.lognot @@ Shostak.Bitv.embed x); true ) else false (* Add the constraint: r = x + c *) - let add_add_const acts r x c = + let add_add_const ~ex acts r x c = let sz = bitwidth r in if Z.equal c Z.zero then ( - add_eq acts r x; + add_eq ~ex acts r x; true ) else if X.is_constant r then ( (* c1 = x + c2 -> x = c1 - c2 *) - add_eq_const acts x Z.(value r - c); + add_eq_const ~ex acts x Z.(value r - c); true ) else if Z.testbit c (sz - 1) then (* Due to the modular nature of arithmetic on bit-vectors, [y = x + c] @@ -719,7 +876,7 @@ end = struct are actually equivalent, so we just pick a normalized order between x and r. *) if X.hash_cmp r x > 0 then ( - acts.acts_add_constraint (bvadd x r (const (bitwidth r) c)); + acts.acts_add_constraint ~ex (bvadd x r (const (bitwidth r) c)); true ) else false @@ -727,16 +884,16 @@ end = struct (* r = x - c -> x = r + c (mod 2^sz) *) let c = Z.neg @@ Z.signed_extract c 0 sz in assert (Z.sign c > 0 && not (Z.testbit c sz)); - acts.acts_add_constraint (bvadd x r (const sz c)); + acts.acts_add_constraint ~ex (bvadd x r (const sz c)); true else false (* Add the constraint: r = x << c *) - let add_shl_const acts r x c = + let add_shl_const ~ex acts r x c = let sz = bitwidth r in match Z.to_int c with - | 0 -> add_eq acts r x + | 0 -> add_eq ~ex acts r x | n when n < sz -> assert (n > 0); let r_bitv = Shostak.Bitv.embed r in @@ -744,32 +901,32 @@ end = struct Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) (Shostak.Bitv.embed x) in - add_eq acts + add_eq ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz n (sz - 1) r_bitv) high_bits; - add_eq_const acts + add_eq_const ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (n - 1) r_bitv) Z.zero | _ | exception Z.Overflow -> - add_eq_const acts r Z.zero + add_eq_const ~ex acts r Z.zero (* Add the constraint: r = x * c *) - let add_mul_const acts r x c = + let add_mul_const ~ex acts r x c = if Z.equal c Z.zero then ( - add_eq_const acts r Z.zero; + add_eq_const ~ex acts r Z.zero; true ) else if Z.popcount c = 1 then ( let ofs = Z.numbits c - 1 in - add_shl_const acts r x (Z.of_int ofs); + add_shl_const ~ex acts r x (Z.of_int ofs); true ) else false (* Add the constraint: r = x >> c *) - let add_lshr_const acts r x c = + let add_lshr_const ~ex acts r x c = let sz = bitwidth r in match Z.to_int c with - | 0 -> add_eq acts r x + | 0 -> add_eq ~ex acts r x | n when n < sz -> assert (n > 0); let r_bitv = Shostak.Bitv.embed r in @@ -777,14 +934,14 @@ end = struct Shostak.Bitv.is_mine @@ Bitv.extract sz n (sz - 1) (Shostak.Bitv.embed x) in - add_eq acts + add_eq ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) r_bitv) low_bits; - add_eq_const acts + add_eq_const ~ex acts (Shostak.Bitv.is_mine @@ Bitv.extract sz (sz - n) (sz - 1) r_bitv) Z.zero | _ | exception Z.Overflow -> - add_eq_const acts r Z.zero + add_eq_const ~ex acts r Z.zero (* Ground evaluation rules for binary operators. *) let eval_binop op ty x y = @@ -823,117 +980,123 @@ end = struct evaluated is assumed to be dealt with prior to calling this function. Algebraic rules (e.g. x & x = x) are in [rw_binop_algebraic].*) - let rw_binop_const acts op r x y = + let rw_binop_const ~ex acts op r x y = (* NB: for commutative operators, arguments are sorted, so the second argument can only be constant if the first argument also is constant. *) match op with | Band when X.is_constant x -> - add_and_const acts r y (value x) + add_and_const ~ex acts r y (value x) | Band when X.is_constant y -> - add_and_const acts r x (value y) + add_and_const ~ex acts r x (value y) | Band -> false | Bor when X.is_constant x -> - add_or_const acts r y (value x) + add_or_const ~ex acts r y (value x) | Bor when X.is_constant y -> - add_or_const acts r x (value y) + add_or_const ~ex acts r x (value y) | Bor -> false | Bxor when X.is_constant x -> - add_xor_const acts r y (value x) + add_xor_const ~ex acts r y (value x) | Bxor when X.is_constant y -> - add_xor_const acts r x (value y) + add_xor_const ~ex acts r x (value y) | Bxor when X.is_constant r -> - add_xor_const acts x y (value r) + add_xor_const ~ex acts x y (value r) | Bxor -> false | Badd when X.is_constant x -> - add_add_const acts r y (value x) + add_add_const ~ex acts r y (value x) | Badd when X.is_constant y -> - add_add_const acts r x (value y) + add_add_const ~ex acts r x (value y) | Badd -> false | Bmul when X.is_constant x -> - add_mul_const acts r y (value x) + add_mul_const ~ex acts r y (value x) | Bmul when X.is_constant y -> - add_mul_const acts r x (value y) + add_mul_const ~ex acts r x (value y) | Bmul -> false | Budiv | Burem -> false (* shifts becomes a simple extraction when we know the right-hand side *) | Bshl when X.is_constant y -> - add_shl_const acts r x (value y); + add_shl_const ~ex acts r x (value y); true | Bshl -> false | Blshr when X.is_constant y -> - add_lshr_const acts r x (value y); + add_lshr_const ~ex acts r x (value y); true | Blshr -> false (* Algebraic rewrite rules for binary operators. Rules based on constant simplifications are in [rw_binop_const]. *) - let rw_binop_algebraic acts op r x y = + let rw_binop_algebraic ~ex acts op r x y = match op with (* x & x = x ; x | x = x *) | Band | Bor when X.equal x y -> - add_eq acts r x; true + add_eq ~ex acts r x; true (* r ^ x ^ x = 0 <-> r = 0 *) | Bxor when X.equal x y -> - add_eq_const acts r Z.zero; true + add_eq_const ~ex acts r Z.zero; true | Bxor when X.equal r x -> - add_eq_const acts y Z.zero; true + add_eq_const ~ex acts y Z.zero; true | Bxor when X.equal r y -> - add_eq_const acts x Z.zero; true + add_eq_const ~ex acts x Z.zero; true | Badd when X.equal x y -> (* r = x + x -> r = 2x -> r = x << 1 *) - add_shl_const acts r x Z.one; true + add_shl_const ~ex acts r x Z.one; true | Badd when X.equal r x -> (* x = x + y -> y = 0 *) - add_eq_const acts y Z.zero; true + add_eq_const ~ex acts y Z.zero; true | Badd when X.equal r y -> (* y = x + y -> x = 0 *) - add_eq_const acts x Z.zero; true + add_eq_const ~ex acts x Z.zero; true | _ -> false - let simplify_binop acts op r x y = + let simplify_binop ~ex acts op r x y = if X.is_constant x && X.is_constant y then ( - add_eq acts r @@ + add_eq ~ex acts r @@ eval_binop op (X.type_info r) (value x) (value y); true ) else - rw_binop_const acts op r x y || - rw_binop_algebraic acts op r x y + rw_binop_const ~ex acts op r x y || + rw_binop_algebraic ~ex acts op r x y - let simplify_fun_t acts r = function - | Fbinop (op, x, y) -> simplify_binop acts op r x y + let simplify_fun_t uf acts r = function + | Fbinop (op, x, y) -> + let r, ex_r = Uf.find_r uf r in + let x, ex_x = Uf.find_r uf x in + let y, ex_y = Uf.find_r uf y in + let ex = Explanation.union ex_r (Explanation.union ex_x ex_y) in + simplify_binop ~ex acts op r x y - let simplify_binrel acts op x y = + let simplify_binrel ~ex acts op x y = match op with | Rugt when X.equal x y -> - acts.Rel_utils.acts_add_eq X.top X.bot; + acts.acts_add_eq ~ex X.top X.bot; true | Rule | Rugt -> false - let simplify_rel_t acts = function - | Rbinrel (op, x, y) -> simplify_binrel acts op x y + let simplify_rel_t uf acts = function + | Rbinrel (op, x, y) -> + let x, ex_x = Uf.find_r uf x in + let y, ex_y = Uf.find_r uf y in + simplify_binrel ~ex:(Explanation.union ex_x ex_y) acts op x y - let simplify_repr acts = function - | Cfun (r, f) -> simplify_fun_t acts r f - | Crel r -> simplify_rel_t acts r + let simplify_view uf acts = function + | Cfun (r, f) -> simplify_fun_t uf acts r f + | Crel r -> simplify_rel_t uf acts r - let simplify c acts = - simplify_repr acts c.repr + let simplify uf c acts = + simplify_view uf acts (view c) end -module Constraints = Rel_utils.Constraints_make(Constraint) - let extract_binop = let open Constraint in function | Sy.BVand -> Some bvand @@ -948,18 +1111,33 @@ let extract_binop = | BVlshr -> Some bvlshr | _ -> None -let extract_constraints bcs uf r t = +let extract_term r terms = + if X.is_a_leaf r then SX.add r terms + else terms + +let extract_constraints terms domain int_domain uf r t = match E.term_view t with | { f = Op op; xs = [ x; y ]; _ } -> ( match extract_binop op with | Some mk -> let rx, exx = Uf.find uf x and ry, exy = Uf.find uf y in - Constraints.add - ~ex:(Ex.union exx exy) (mk r rx ry) bcs - | _ -> bcs + let c = mk r rx ry in + let ex = Ex.union exx exy in + let domain = + Bitlist_domains.watch (explained ~ex c) rx @@ + Bitlist_domains.watch (explained ~ex c) ry @@ + domain + in + let int_domain = + Interval_domains.watch (explained ~ex c) rx @@ + Interval_domains.watch (explained ~ex c) ry @@ + int_domain + in + terms, domain, int_domain + | None -> extract_term r terms, domain, int_domain ) - | _ -> bcs + | _ -> extract_term r terms, domain, int_domain let rec mk_eq ex lhs w z = match lhs with @@ -1010,7 +1188,7 @@ let add_eqs = module Any_constraint = struct type t = - | Constraint of Constraint.t Rel_utils.explained + | Constraint of Constraint.t explained | Structural of X.r (** Structural constraint associated with [X.r]. See {!Rel_utils.Bitlist_domains.structural_propagation}. *) @@ -1025,17 +1203,27 @@ module Any_constraint = struct | Constraint c -> 2 * Constraint.hash c.value | Structural r -> 2 * X.hash r + 1 - let propagate constraint_propagate structural_propagation c d = + let propagate constraint_propagate structural_propagation c = Steps.incr CP; match c with | Constraint { value; explanation = ex } -> - constraint_propagate ~ex value d + constraint_propagate ~ex value | Structural r -> - structural_propagation d r + structural_propagation r end module QC = Uqueue.Make(Any_constraint) +let propagate_queue queue constraint_propagate structural_propagation = + try + while true do + Any_constraint.propagate + constraint_propagate + structural_propagation + (QC.pop queue) + done + with QC.Empty -> () + let finite_lower_bound = function | Intervals_intf.Unbounded -> Z.zero | Closed n -> n @@ -1063,8 +1251,9 @@ let finite_upper_bound ~size:sz = function five most-significant bits, denoted [00110???]. Therefore, a bit-vector bl = [0??1???0] can be refined into [00110??0]. *) let constrain_bitlist_from_interval ~size:sz bv int = - let open Bitlist_domains.Ephemeral in - + let + open Rel_utils.HandleNotations(Bitlist_domain)(Bitlist_domains.Ephemeral) + in let inf, inf_ex = Intervals.Int.lower_bound int in let inf = finite_lower_bound inf in let sup, sup_ex = Intervals.Int.upper_bound int in @@ -1094,7 +1283,9 @@ let constrain_bitlist_from_interval ~size:sz bv int = [Bitlist.decrease_upper_bound] on all the constituent intervals of an union; see the documentation of these functions for details. *) let constrain_interval_from_bitlist ~size:sz int bv = - let open Interval_domains.Ephemeral in + let + open Rel_utils.HandleNotations(Interval_domain)(Interval_domains.Ephemeral) + in let ex = Bitlist.explanation bv in (* Handy wrapper around [of_complement] *) let remove ~ex i2 i1 = @@ -1144,126 +1335,252 @@ let constrain_interval_from_bitlist ~size:sz int bv = acc ) !!int !!int -let propagate_bitlist queue touched bcs dom = - let touch_c c = QC.push queue (Constraint c) in - let touch r = - HX.replace touched r (); - QC.push queue (Structural r); - Constraints.iter_parents touch_c r bcs - in - try - while true do - Bitlist_domains.Ephemeral.iter_changed touch dom; - Bitlist_domains.Ephemeral.clear_changed dom; - Any_constraint.propagate - Constraint.propagate_bitlist - Bitlist_domains.Ephemeral.structural_propagation - (QC.pop queue) dom - done - with QC.Empty -> () +let iter_parents a f t = + match Rel_utils.XComparable.Map.find a t with + | cs -> Rel_utils.XComparable.Set.iter f cs + | exception Not_found -> () + +let propagate_bitlist queue vars dom = + let structural_propagation r = + let open Rel_utils.HandleNotations(Bitlist_domain)(Bitlist_domains_uf) in + let get r = !!(Bitlist_domains_uf.entry dom r) in + let update r d = + update ~ex:Explanation.empty (Bitlist_domains_uf.entry dom r) d + in + if X.is_a_leaf r then + iter_parents r (fun p -> + if X.is_a_leaf p then + assert (X.equal r p) + else + update p (CompositeBitlistDomain.map_domain get p) + ) vars + else + let iter_signed sz f { Bitv.value; negated } bl = + let bl = if negated then Bitlist.(extract (lognot bl)) 0 sz else bl in + f value bl + in + ignore @@ List.fold_left (fun (bl, w) { Bitv.bv; sz } -> + (* Extract the bitlist associated with the current component *) + let mid = w - sz in + let bl_tail = Bitlist.extract bl 0 mid in + let bl = Bitlist.extract bl mid (w - mid) in -let propagate_intervals queue touched bcs dom = - let touch_c c = QC.push queue (Constraint c) in - let touch r = - HX.replace touched r (); - QC.push queue (Structural r); - Constraints.iter_parents touch_c r bcs + match bv with + | Bitv.Cte z -> + assert (Z.numbits z <= sz); + (* Nothing to update, but still check for consistency! *) + ignore @@ Bitlist.intersect bl (Bitlist.exact z Ex.empty); + bl_tail, mid + | Other r -> + iter_signed sz update r bl; + (bl_tail, mid) + | Ext (r, r_size, i, j) -> + (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) + assert (i + r_size - j - 1 + w - mid = r_size); + let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in + let lo = Bitlist.(extract unknown 0 i) in + let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in + iter_signed r_size update r bl_hd; + (bl_tail, mid) + ) ((get r), (bitwidth r)) (Shostak.Bitv.embed r) in - try - while true do - Interval_domains.Ephemeral.iter_changed touch dom; - Interval_domains.Ephemeral.clear_changed dom; - Any_constraint.propagate - Constraint.propagate_interval - Interval_domains.Ephemeral.structural_propagation - (QC.pop queue) dom - done - with QC.Empty -> () - -let propagate_all eqs bcs bdom idom = - (* Call [simplify_pending] first because it can remove constraints from the - pending set. *) - let eqs, bcs = Constraints.simplify_pending eqs bcs in + propagate_queue + queue + (Propagator.propagate_bitlist dom) + structural_propagation + +let propagate_intervals queue vars dom = + let structural_propagation r = + let open Rel_utils.HandleNotations(Interval_domain)(Interval_domains_uf) in + let get r = !!(Interval_domains_uf.entry dom r) in + let update r d = + update ~ex:Explanation.empty (Interval_domains_uf.entry dom r) d + in + if X.is_a_leaf r then + iter_parents r (fun p -> + if X.is_a_leaf p then + assert (X.equal r p) + else + update p (CompositeIntervalDomain.map_domain get p) + ) vars + else + let iter_signed f { Bitv.value; negated } sz int = + f value (if negated then Interval_domain.lognot sz int else int) + in + let int = get r in + let width = bitwidth r in + let j = + List.fold_left (fun j { Bitv.bv; sz } -> + (* sz = j - i + 1 => i = j - sz + 1 *) + let int = Intervals.Int.extract int ~ofs:(j - sz + 1) ~len:sz in + + begin match bv with + | Bitv.Cte z -> + (* Nothing to update, but still check for consistency *) + ignore @@ + Interval_domain.intersect int (Interval_domain.constant z) + | Other s -> iter_signed update s sz int + | Ext (r, r_size, i, j) -> + (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) + assert (i + r_size - j - 1 + sz = r_size); + let lo = Interval_domain.unknown (Tbitv i) in + let int = Intervals.Int.scale Z.(~$1 lsl i) int in + let hi = Interval_domain.unknown (Tbitv (r_size - j - 1)) in + let hi = + Intervals.Int.scale Z.(~$1 lsl Stdlib.(i + sz)) hi + in + iter_signed update r r_size Intervals.Int.(add hi (add int lo)) + end; + + (j - sz) + ) (width - 1) (Shostak.Bitv.embed r) + in + assert (j = -1) + in + propagate_queue + queue + (Propagator.propagate_interval dom) + structural_propagation + +module HC = Hashtbl.Make(Constraint) + +let simplify_all uf eqs touched (dom, idom) = + let eqs = ref eqs in + let to_add = HC.create 17 in + let simplify c c_ex (dom, idom) = + let acts_add_lit_view ~ex l = + eqs := (l, Explanation.union ex c_ex) :: !eqs + in + let acts_add_eq ~ex u v = + acts_add_lit_view ~ex (Uf.LX.mkv_eq u v) + in + let acts_add_constraint ~ex c = + let c = explained ~ex:(Explanation.union ex c_ex) c in + HC.replace to_add c.value c.explanation + in + let acts = + { acts_add_lit_view + ; acts_add_eq + ; acts_add_constraint } in + if Propagator.simplify uf c acts then + let c = explained ~ex:c_ex c in + (Bitlist_domains.unwatch c dom, Interval_domains.unwatch c idom) + else + (dom, idom) + in + let dom, idom = HC.fold simplify touched (dom, idom) in + !eqs, + HC.fold (fun c c_ex (dom, idom) -> + let c = explained ~ex:c_ex c in + Constraint.fold_args (fun r (dom, idom) -> + let r, _ = Uf.find_r uf r in + Bitlist_domains.watch c r dom, + Interval_domains.watch c r idom + ) c.value (dom, idom) + ) to_add (dom, idom) + +let rec propagate_all uf eqs bdom idom = (* Optimization to avoid unnecessary allocations *) - let do_all = Constraints.has_pending bcs in - let do_bitlist = do_all || Bitlist_domains.has_changed bdom in - let do_intervals = do_all || Interval_domains.has_changed idom in + let do_bitlist = Bitlist_domains.needs_propagation bdom in + let do_intervals = Interval_domains.needs_propagation idom in let do_any = do_bitlist || do_intervals in if do_any then - let queue = QC.create 17 in - let touch_pending queue = - Constraints.iter_pending (fun c -> QC.push queue (Constraint c)) bcs + let shostak_candidates = HX.create 17 in + let seen_constraints = HC.create 17 in + let bitlist_queue = QC.create 17 in + let interval_queue = QC.create 17 in + let touch_c queue c = + HC.replace seen_constraints c.value c.explanation; + QC.push queue (Constraint c) + in + let touch tbl queue r = + HX.replace tbl r (); + QC.push queue (Structural r) in let bitlist_changed = HX.create 17 in - let touched = HX.create 17 in - let bdom = Bitlist_domains.edit bdom in - let idom = Interval_domains.edit idom in + let interval_changed = HX.create 17 in + let bitlist_events = + { Rel_utils.evt_atomic_change = touch bitlist_changed bitlist_queue + ; evt_composite_change = touch bitlist_changed bitlist_queue + ; evt_watch_trigger = touch_c bitlist_queue } + and interval_events = + { Rel_utils.evt_atomic_change = touch interval_changed interval_queue + ; evt_composite_change = touch interval_changed interval_queue + ; evt_watch_trigger = touch_c interval_queue } + in + let bvars = Bitlist_domains.parents bdom in + let ivars = Interval_domains.parents idom in + + let bdom = Bitlist_domains.edit ~events:bitlist_events bdom in + let idom = Interval_domains.edit ~events:interval_events idom in + + let uf_bdom = Bitlist_domains_uf.wrap uf bdom in + let uf_idom = Interval_domains_uf.wrap uf idom in (* First, we propagate the pending constraints to both domains. Changes in the bitlist domain are used to shrink the interval domains. *) - touch_pending queue; - propagate_bitlist queue touched bcs bdom; - assert (QC.is_empty queue); + propagate_bitlist bitlist_queue bvars uf_bdom; + assert (QC.is_empty bitlist_queue); - touch_pending queue; HX.iter (fun r () -> - HX.replace bitlist_changed r (); - let sz = bitwidth r in - constrain_interval_from_bitlist ~size:sz - Interval_domains.Ephemeral.(handle idom r) - Bitlist_domains.Ephemeral.(!!(handle bdom r)) - ) touched; - HX.clear touched; - propagate_intervals queue touched bcs idom; - assert (QC.is_empty queue); + HX.replace shostak_candidates r (); + constrain_interval_from_bitlist ~size:(bitwidth r) + Interval_domains.Ephemeral.(entry idom r) + Bitlist_domains.Ephemeral.(Entry.domain (entry bdom r)) + ) bitlist_changed; + HX.clear bitlist_changed; + propagate_intervals interval_queue ivars uf_idom; + assert (QC.is_empty interval_queue); (* Now the interval domain is stable, but the new intervals may have an impact on the bitlist domains, so we must shrink them again when applicable. We repeat until a fixpoint is reached. *) - let bcs = Constraints.clear_pending bcs in - while HX.length touched > 0 do + while HX.length interval_changed > 0 do HX.iter (fun r () -> - let sz = bitwidth r in - constrain_bitlist_from_interval ~size:sz - Bitlist_domains.Ephemeral.(handle bdom r) - Interval_domains.Ephemeral.(!!(handle idom r)) - ) touched; - HX.clear touched; - propagate_bitlist queue touched bcs bdom; - assert (QC.is_empty queue); + constrain_bitlist_from_interval ~size:(bitwidth r) + Bitlist_domains.Ephemeral.(entry bdom r) + Interval_domains.Ephemeral.(Entry.domain (entry idom r)) + ) interval_changed; + HX.clear interval_changed; + propagate_bitlist bitlist_queue bvars uf_bdom; + assert (QC.is_empty bitlist_queue); HX.iter (fun r () -> - let sz = bitwidth r in - HX.replace bitlist_changed r (); - constrain_interval_from_bitlist ~size:sz - Interval_domains.Ephemeral.(handle idom r) - Bitlist_domains.Ephemeral.(!!(handle bdom r)) - ) touched; - HX.clear touched; - propagate_intervals queue touched bcs idom; - assert (QC.is_empty queue); + HX.replace shostak_candidates r (); + constrain_interval_from_bitlist ~size:(bitwidth r) + Interval_domains.Ephemeral.(entry idom r) + Bitlist_domains.Ephemeral.(Entry.domain (entry bdom r)) + ) bitlist_changed; + HX.clear bitlist_changed; + propagate_intervals interval_queue ivars uf_idom; + assert (QC.is_empty interval_queue); done; let eqs = HX.fold (fun r () acc -> - let d = Bitlist_domains.Ephemeral.(!!(handle bdom r)) in - let sz = bitwidth r in - add_eqs acc (Shostak.Bitv.embed r) sz d - ) bitlist_changed eqs + let d = Bitlist_domains.Ephemeral.(Entry.domain (entry bdom r)) in + add_eqs acc (Shostak.Bitv.embed r) (bitwidth r) d + ) shostak_candidates eqs + in + + let bdom, idom = + Bitlist_domains.snapshot bdom, Interval_domains.snapshot idom in + let eqs, (bdom, idom) = simplify_all uf eqs seen_constraints (bdom, idom) in - eqs, bcs, Bitlist_domains.snapshot bdom, Interval_domains.snapshot idom + (* Propagate again in case constraints were simplified. *) + propagate_all uf eqs bdom idom else - eqs, bcs, bdom, idom + eqs, bdom, idom type t = { delayed : Rel_utils.Delayed.t - ; constraints : Constraints.t + ; terms : SX.t ; size_splits : Q.t } let empty uf = { delayed = Rel_utils.Delayed.create ~is_ready:X.is_constant dispatch - ; constraints = Constraints.empty + ; terms = SX.empty ; size_splits = Q.one }, Uf.GlobalDomains.add (module Bitlist_domains) Bitlist_domains.empty @@ Uf.GlobalDomains.add (module Interval_domains) Interval_domains.empty @@ @@ -1276,54 +1593,56 @@ let assume env uf la = Uf.GlobalDomains.find (module Interval_domains) ds in let delayed, result = Rel_utils.Delayed.assume env.delayed uf la in - let (domain, int_domain, constraints, eqs, size_splits) = + let (domain, int_domain, eqs, size_splits) = try - let (constraints, eqs, size_splits) = - List.fold_left (fun (bcs, eqs, ss) (a, _root, ex, orig) -> - let ss = - match orig with - | Th_util.CS (Th_bitv, n) -> Q.(ss * n) - | _ -> ss - in - let is_1bit r = - match X.type_info r with - | Tbitv 1 -> true - | _ -> false - in - match a, orig with - | L.Eq (rr, nrr), Subst when is_bv_r rr -> - let bcs = Constraints.subst ~ex rr nrr bcs in - (bcs, eqs, ss) - | Builtin (is_true, BVULE, [x; y]), _ -> - let x, exx = Uf.find_r uf x in - let y, exy = Uf.find_r uf y in - let ex = Ex.union ex @@ Ex.union exx exy in - let c = - if is_true then - Constraint.bvule x y - else - Constraint.bvugt x y - in - let bcs = Constraints.add ~ex c bcs in - (bcs, eqs, ss) - | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> - (* We don't (yet) support [distinct] in general, but we must - support it for case splits to avoid looping. - - We are a bit more general and support it for 1-bit vectors, for - which `distinct` can be expressed using `bvnot`. *) - let not_nrr = - Shostak.Bitv.is_mine (Bitv.lognot (Shostak.Bitv.embed nrr)) - in - (bcs, (Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss) - | _ -> (bcs, eqs, ss) + let (domain, int_domain, eqs, size_splits) = + List.fold_left + (fun (domain, int_domain, eqs, ss) (a, _root, ex, orig) -> + let ss = + match orig with + | Th_util.CS (Th_bitv, n) -> Q.(ss * n) + | _ -> ss + in + let is_1bit r = + match X.type_info r with + | Tbitv 1 -> true + | _ -> false + in + match a, orig with + | L.Builtin (is_true, BVULE, [x; y]), _ -> + let x, exx = Uf.find_r uf x in + let y, exy = Uf.find_r uf y in + let ex = Ex.union ex @@ Ex.union exx exy in + let c = + if is_true then + Constraint.bvule x y + else + Constraint.bvugt x y + in + (* Only watch comparisons on the interval domain since we don't + propagate them in the bitlist domain. . *) + let int_domain = + Interval_domains.watch (explained ~ex c) x @@ + Interval_domains.watch (explained ~ex c) y @@ + int_domain + in + (domain, int_domain, eqs, ss) + | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> + (* We don't (yet) support [distinct] in general, but we must + support it for case splits to avoid looping. + + We are a bit more general and support it for 1-bit vectors, + for which `distinct` can be expressed using `bvnot`. *) + let not_nrr = + Shostak.Bitv.is_mine (Bitv.lognot (Shostak.Bitv.embed nrr)) + in + (domain, int_domain, (Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss) + | _ -> (domain, int_domain, eqs, ss) ) - (env.constraints, [], env.size_splits) + (domain, int_domain, [], env.size_splits) la in - let eqs, constraints, domain, int_domain = - propagate_all eqs constraints domain int_domain - in + let eqs, domain, int_domain = propagate_all uf eqs domain int_domain in if Options.get_debug_bitv () && not (Lists.is_empty eqs) then ( Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" @@ -1331,11 +1650,8 @@ let assume env uf la = Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" "interval domain: @[%a@]" Interval_domains.pp int_domain; - Printer.print_dbg - ~module_name:"Bitv_rel" ~function_name:"assume" - "bitlist constraints: @[%a@]" Constraints.pp constraints; ); - (domain, int_domain, constraints, eqs, size_splits) + (domain, int_domain, eqs, size_splits) with Bitlist.Inconsistent ex | Interval_domain.Inconsistent ex -> raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf) in @@ -1345,7 +1661,7 @@ let assume env uf la = let result = { result with assume = List.rev_append assume result.assume } in - { delayed ; constraints ; size_splits }, + { delayed ; size_splits ; terms = env.terms }, Uf.GlobalDomains.add (module Bitlist_domains) domain @@ Uf.GlobalDomains.add (module Interval_domains) int_domain ds, result @@ -1365,7 +1681,9 @@ let case_split env uf ~for_model = constrained variables, all the remaining variables. [nunk] is the number of unknown bits. *) - let f_acc r bl acc = + let f_acc r acc = + let r, _ = Uf.find_r uf r in + let bl = Bitlist_domains.get r domain in let nunk = Z.popcount (Bitlist.unknown_bits bl) in if nunk = 0 then acc @@ -1376,28 +1694,21 @@ let case_split env uf ~for_model = Some (nunk', SX.add r xs) | _ -> Some (nunk, SX.singleton r) in - let f_acc' r acc = - let r, _ = Uf.find_r uf r in - List.fold_left (fun acc { Bitv.bv; _ } -> - match bv with - | Bitv.Cte _ -> acc - | Other r | Ext (r, _, _, _) -> - let bl = Bitlist_domains.get r.value domain in - f_acc r.value bl acc - ) acc (Shostak.Bitv.embed r) - in let _, candidates = - match Constraints.fold_args f_acc' env.constraints None with + match SX.fold f_acc env.terms None with | Some (nunk, xs) -> nunk, xs - | _ -> - match Bitlist_domains.fold_leaves f_acc domain None with - | Some (nunk, xs) -> nunk, xs - | None -> 0, SX.empty + | None -> 0, SX.empty in (* For now, just pick a value for the most significant bit. *) match SX.choose candidates with | r -> - let bl = Bitlist_domains.get r domain in + let rr, _ = Uf.find_r uf r in + let bl = Bitlist_domains.get rr domain in + let r = + let es = Uf.rclass_of uf r in + try Uf.make uf (Expr.Set.choose es) + with Not_found -> r + in let w = bitwidth r in let unknown = Z.extract (Bitlist.unknown_bits bl) 0 w in let bitidx = Z.numbits unknown - 1 in @@ -1415,15 +1726,22 @@ let case_split env uf ~for_model = | exception Not_found -> [] let add env uf r t = + let ds = Uf.domains uf in let delayed, eqs = Rel_utils.Delayed.add env.delayed uf r t in - let env, eqs = + let env, ds, eqs = if is_bv_r r then - let constraints = extract_constraints env.constraints uf r t in - { env with constraints }, eqs + let dom = Uf.GlobalDomains.find (module Bitlist_domains) ds in + let idom = Uf.GlobalDomains.find (module Interval_domains) ds in + let terms, dom, idom = extract_constraints env.terms dom idom uf r t in + { env with terms }, + Uf.GlobalDomains.add (module Bitlist_domains) dom @@ + Uf.GlobalDomains.add (module Interval_domains) idom @@ + ds, + eqs else - env, eqs + env, ds, eqs in - { env with delayed }, Uf.domains uf, eqs + { env with delayed }, ds, eqs let optimizing_objective _env _uf _o = None diff --git a/src/lib/reasoners/rel_utils.ml b/src/lib/reasoners/rel_utils.ml index 854077ff20..fd1267e948 100644 --- a/src/lib/reasoners/rel_utils.ml +++ b/src/lib/reasoners/rel_utils.ml @@ -217,6 +217,136 @@ end = struct MX.iter (fun r -> OMap.iter (fun op -> Expr.Set.iter (f r op))) t.used_by end +module type Map_like = sig + (** Minimal signature for a persistent map type, used by [EphemeralMap]. *) + + type 'a t + + type key + + val find : key -> 'a t -> 'a + + val add : key -> 'a -> 'a t -> 'a t +end + +module type Hashtbl_like = sig + (** Minimal signature for an imperative map type, used by [EphemeralMap]. *) + + type 'a t + + type key + + val create : int -> 'a t + + val find : 'a t -> key -> 'a + + val replace : 'a t -> key -> 'a -> unit + + val fold : (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b +end + +module EphemeralMap + (MX : Map_like) + (HX : Hashtbl_like with type key = MX.key) + : sig + (** This module implements an ephemeral (mutable) interface for efficient + (repeated) lookup and update to the underlying persistent map, as well + as conversion functions between persistent and ephemeral maps. *) + + type 'a t + (** The type of ephemeral maps with values of type ['a]. *) + + type key = MX.key + (** The type of keys in the ephemeral map. *) + + module Entry : sig + (** Entries associate a (mutable) content to keys in the map. *) + + type 'a t + (** The type of entries with values ['a]. *) + + val content : 'a t -> 'a + (** [content e] is the content associated with [key e] in the map. *) + + val set_content : 'a t -> 'a -> unit + (** [set_content e v] sets the content of entry [e] to [v]. This + overwrites any pre-existing content associated with [e]. *) + end + + val entry : 'a t -> key -> 'a Entry.t + (** [entry t k] returns an entry associated with key [k] in the map. + + Each key is associated with a single entry: calling [entry t k] several + times will always return the same entry. *) + + val edit : default:(key -> 'a) -> 'a MX.t -> 'a t + (** [edit ~default t] returns an ephemeral copy of [t] for edition. + + The [default] argument is used to compute a default value for missing + keys. *) + + val snapshot : 'a t -> 'a MX.t + (** [snapshot t] computes a persistent snapshot of the ephemeral map [t], + applying all the changes made using [set_content]. Entries that were + never written to using [set_content] are unchanged, even if they contain + a [default] value due to not present in the map when it was [edit]ed. *) + end = +struct + type key = MX.key + + module Entry = struct + type 'a t = + { key : MX.key + ; mutable value : 'a + ; mutable dirty : bool + ; dirty_cache : 'a t HX.t } + + let content { value; _ } = value + + let set_dirty handle = + if not handle.dirty then ( + handle.dirty <- true; + HX.replace handle.dirty_cache handle.key handle + ) + + let set_content handle value = + set_dirty handle; + handle.value <- value + end + + type 'a t = + { values : 'a MX.t + ; handles : 'a Entry.t HX.t + ; dirty_cache : 'a Entry.t HX.t + ; default : MX.key -> 'a } + + let entry t r = + try HX.find t.handles r with Not_found -> + let handle = + { Entry.key = r + ; value = (try MX.find r t.values with Not_found -> t.default r) + ; dirty = false + ; dirty_cache = t.dirty_cache } + in + HX.replace t.handles r handle; + handle + + let edit ~default t = + let size = 17 in + { values = t + ; handles = HX.create size + ; dirty_cache = HX.create size + ; default } + + let snapshot t = + let persistent = t.values in + HX.fold (fun repr handle t -> + (* NB: we are in the [dirty_cache] so we know that the domain has been + updated. *) + MX.add repr (Entry.content handle) t + ) t.dirty_cache persistent +end + module type Domain = sig type t (** The type of domains for a single value. @@ -239,6 +369,12 @@ module type Domain = sig val filter_ty : Ty.t -> bool (** Filter for the types of values this domain can be attached to. *) + type constant + (** The type of constant values. *) + + val constant : constant -> t + (** [constant c] returns the singleton domain {m \{ c \}}. *) + val unknown : Ty.t -> t (** [unknown ty] returns a full domain for values of type [t]. @@ -256,634 +392,865 @@ module type Domain = sig @raise Inconsistent if [d1] and [d2] are not compatible (the intersection would be empty). *) +end +module type OffsetDomain = sig + (** This module represents domains to which (constant) offsets can be added or + removed. It extends the [Domain] signature. *) - val fold_leaves : (X.r -> t -> 'a -> 'a) -> X.r -> t -> 'a -> 'a - (** [fold_leaves f r t acc] folds [f] over the leaves of [r], deconstructing - the domain [t] according to the structure of [r]. + include Domain + + val add_offset : t -> constant -> t + (** [add_offset ofs d] adds the offset [ofs] to domain [d]. *) + + val sub_offset : t -> constant -> t + (** [sub_offset ofs d] removes the offset [ofs] from domain [d]. *) +end + +module type EphemeralDomainMap = sig + (** This module provides a signature for ephemeral domain maps: imperative + mappings from some key type to a domain type. *) + + type t + (** The type of ephemeral domain maps, i.e. an imperative structure mapping + keys to their current domain. *) - It is assumed that [t] already contains any justification required for - it to apply to [r]. + type key + (** The type of keys in the ephemeral map. *) - @raise Inconsistent if [r] cannot possibly be in the domain of [t]. *) + type domain + (** The type of domains. *) - val map_leaves : (X.r -> t) -> X.r -> t - (** [map_leaves f r] is the "inverse" of [fold_leaves] in the sense that - it rebuilds a domain for [r] by using [f] to access the domain for each - of [r]'s leaves. *) + module Entry : sig + type t + (** A mutable entry associated with a given key. Can be used to access and + update the associated domain imperatively. A single (physical) entry is + associated with a given key. *) + + val domain : t -> domain + (** Return the domain associated with this entry. *) + + val set_domain : t -> domain -> unit + (** Intersect the domain associated with this entry and the provided + [domain]. The explanation [ex] justifies that the [domain] applies to + the entry's key. + + @raise Domain.Inconsistent if the intersection is empty. *) + end + + val entry : t -> key -> Entry.t + (** [entry t k] returns the [handle] associated with [k]. + + There is a unique entry associated with each key [k] that is created + on-the-fly when [handle t k] is called for the first time. + + The domain associated with the entry is initialized from the underlying + persistent domain the first time it is accessed, and updated with + [update]. *) end -module type Domains = sig - (** Extended signature for global domains. *) +module type OrderedType = sig + (** Module signature for an ordered type equipped with a [compare] function. + + This is similar to [Set.OrderedType] and [Map.OrderedType], but includes + pre-built [Set] and [Map] modules. *) + + type t - include Uf.GlobalDomain + val pp : t Fmt.t - type elt - (** The type of domains contained in the map. Each domain of type [elt] - applies to a single semantic value. *) + val compare : t -> t -> int - val get : X.r -> t -> elt - (** [get r t] returns the domain currently associated with [r] in [t]. *) + module Set : Set.S with type elt = t - val fold_leaves : (X.r -> elt -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold f t acc] folds [f] over all the domains in [t] that are associated - with leaves. *) + module Map : Map.S with type key = t +end - val has_changed : t -> bool - (** Returns [true] if any element is marked as changed. This can be used to - avoid unnecessary calls to [edit]. +module type ComparableType = sig + (** Module signature combining [OrderedType] and [Hashtbl.HashedType]. - Elements are marked as changed when their domain shrinks due to a call to - [subst], or through the ephemeral API. Elements can be unmarked by - [clear_changed] in the ephemeral API. *) + This includes a pre-built [Table] module that implements the [Hashtbl.S] + signature. *) - module Ephemeral : sig - type handle - (** A mutable handle to the domain associated with a semantic value. Can be - used to access and update the domain. *) + include OrderedType - val (!!) : handle -> elt - (** Return the domain associated with the [handle]. *) + val equal : t -> t -> bool - val update : ex:Explanation.t -> handle -> elt -> unit - (** Intersect the domain associated with the [handle] with the provided - [domain]. The explanation [ex] justifies that the [domain] applies to - the [handle]'s representative. + val hash : t -> int - If this changes the domain associated with the handle, the handle is - marked as changed. + module Table : Hashtbl.S with type key = t +end - @raise Domain.Inconsistent if the intersection is empty. *) +module DomainMap + (X : ComparableType) + (D : Domain) + : sig + (** A persistent map to a domain type, with an ephemeral interface. *) type t - (** Mutable mappings from semantic values to [domain]s. *) + (** The type of domain maps. *) - val handle : t -> X.r -> handle - (** [handle t r] returns the [handle] associated with [r]. + val pp : t Fmt.t + (** Pretty-printer for domain maps. *) - There is a unique handle associated with each semantic value [r] that is - created on-the-fly when [handle t r] is called for the first time. + val empty : t + (** The empty domain map. *) - The domain associated with the handle is initialized from the - underlying persistent domain the first time it is accessed, and updated - with [update]. *) + type key = X.t + (** The type of keys in the map. *) - val structural_propagation : t -> X.r -> unit - (** Perform structural propagation for the given representative. + type domain = D.t + (** The type of per-variable domains. *) - More precisely, if [r] is a leaf, the domain of [r] is propagated to any - semantic value that contains [r] as a leaf according to the structure of - that semantic value (using [Domain.map_leaves]); if [r] is not a leaf, - its domain is propagated to any of the leaves it contains. + val find : key -> t -> domain + (** Find the domain associatd with the given key. - We only perform *forward* structural propagation: if structural - propagation causes a domain of a leaf or parent to be changed, then we - only mark that leaf or parent as changed. + @raise Not_found if there is no domain associated with the key. *) - @raise Inconsistent if an inconsistency if detected during structural - propagation. *) + val add : key -> domain -> t -> t + (** Adds a domain associated with a given key. - val iter_changed : (X.r -> unit) -> t -> unit - (** Iterate over all the semantic values that have been marked as changed - since the last call to [clear_changed]. Values are marked as changed by - [update] whenever their domain shrinks. + {b Warning}: If the key is not constant, [add] updates the domain + associated with the variable part of the key, and hence influences the + domains of other keys that have the same variable part as this key. *) - {b Warning}: The behavior is not specified if the ephemeral domain is - modified during iteration, such as by calling [update] or - [structural_propagation]. *) + val remove : key -> t -> t + (** Removes the domain associated with a single variable. This will + effectively remove the domains associated with all keys that have the + same variable part. *) - val clear_changed : t -> unit - (** Remove the [changed] flag from all values. *) - end + val needs_propagation : t -> bool + (** Returns [true] if the domain map needs propagation, i.e. if the domain + associated with any variable has changed. *) - val edit : t -> Ephemeral.t - (** [edit d] returns an ephemeral version of the domain that can be used for - editing. *) + module Ephemeral : EphemeralDomainMap + with type key = key and type domain = domain - val snapshot : Ephemeral.t -> t - (** [snapshot e] returns a persistent version of [e]. *) -end + val edit : + notify:(key -> unit) -> default:(key -> domain) -> t -> Ephemeral.t + (** Create an ephemeral domain map from the current domain map. -module Domains_make(Domain : Domain) : Domains with type elt = Domain.t = -struct - type elt = Domain.t + [notify] will be called whenever the domain associated with a variable + changes. *) - exception Inconsistent = Domain.Inconsistent + val snapshot : Ephemeral.t -> t + (** Convert back a (modified) ephemeral domain map into a persistent one. *) + end - type t = { - domains : Domain.t MX.t ; - (** Map from tracked representatives to their domain *) += +struct + module MX = X.Map + module SX = X.Set + module HX = X.Table + module EX = EphemeralMap(MX)(HX) - changed : SX.t ; - (** Representatives whose domain has changed since the last flush *) + type t = + { domains : D.t MX.t + ; changed : SX.t } - leaves_map : SX.t MX.t ; - (** Map from leaves to the *tracked* representatives that contains them *) - } + type key = X.t - type _ Uf.id += Id : t Uf.id + type domain = D.t let pp ppf t = - Fmt.(iter_bindings ~sep:semi MX.iter - (box @@ pair ~sep:(any " ->@ ") X.print Domain.pp) - ) + Fmt.iter_bindings ~sep:Fmt.semi MX.iter + (Fmt.box @@ Fmt.pair ~sep:(Fmt.any " ->@ ") X.pp D.pp) ppf t.domains let empty = - { domains = MX.empty ; changed = SX.empty ; leaves_map = MX.empty } - - let filter_ty = Domain.filter_ty - - let r_add r leaves_map = - List.fold_left (fun leaves_map leaf -> - MX.update leaf (function - | Some parents -> Some (SX.add r parents) - | None -> Some (SX.singleton r) - ) leaves_map - ) leaves_map (X.leaves r) - - let create_domain r = - Domain.map_leaves (fun r -> - Domain.unknown (X.type_info r) - ) r - - let add r t = - if MX.mem r t.domains then t else - (* Note: we do not need to mark [r] as needing propagation, because no - constraints applied to it yet. Any constraint that apply to [r] will - already be marked as pending due to being newly added. *) - let d = create_domain r in - let domains = MX.add r d t.domains in - let leaves_map = r_add r t.leaves_map in - { t with domains; leaves_map } - - let r_remove r leaves_map = - List.fold_left (fun leaves_map leaf -> - MX.update leaf (function - | Some parents -> - let parents = SX.remove r parents in - if SX.is_empty parents then None else Some parents - | None -> None - ) leaves_map - ) leaves_map (X.leaves r) - - let remove r t = - let changed = SX.remove r t.changed in - let domains = MX.remove r t.domains in - let leaves_map = r_remove r t.leaves_map in - { changed; domains; leaves_map } - - let get r t = - (* We need to catch [Not_found] because of fresh terms that can be added - by the solver and for which we don't call [add]. Note that in this - case, only structural constraints can apply to [r]. *) - try MX.find r t.domains with Not_found -> create_domain r - - (* Marked as unsafe because we trust the [changed] flag from the caller. *) - let unsafe_update ?(changed = true) r d t = - match MX.find r t.domains with - | od -> - (* Both domains are already valid for [r], we can intersect them - without additional justifications. *) - let d = Domain.intersect od d in - if Domain.equal od d then - t - else - let domains = MX.add r d t.domains in - let changed = if changed then SX.add r t.changed else t.changed in - { t with domains; changed } - | exception Not_found -> - (* We need to catch [Not_found] because of fresh terms that can be added - by the solver and for which we don't call [add]. *) - let d = Domain.intersect d (create_domain r) in - let domains = MX.add r d t.domains in - let changed = if changed then SX.add r t.changed else t.changed in - let leaves_map = r_add r t.leaves_map in - { domains; changed; leaves_map } - - let fold_leaves f t acc = - MX.fold (fun r _ acc -> - f r (get r t) acc - ) t.leaves_map acc + { domains = MX.empty + ; changed = SX.empty } - let subst ~ex rr nrr t = - (* Need to add [ex] to be a valid domain for [nrr] *) - let d = Domain.add_explanation ~ex (get rr t) in - let changed = SX.mem rr t.changed in - let t = remove rr t in - match MX.find nrr t.domains with - | nd -> - (* If there is an existing domain for [nrr], there might be - constraints applying to [nrr] prior to the substitution, and the - constraints that used to apply to [rr] will also apply to [nrr] - after the substitution. - - We need to notify changed to either of these constraints, so we - must notify if the domain is different from *either* the old - domain of [rr] or the old domain of [nrr]. *) - let nnd = Domain.intersect d nd in - let nrr_changed = not (Domain.equal nnd nd) in - let rr_changed = not (Domain.equal nnd d) in - let domains = - if nrr_changed then MX.add nrr nnd t.domains else t.domains - in - let changed = changed || rr_changed || nrr_changed in - let changed = - if changed then SX.add nrr t.changed else t.changed - in - { t with domains; changed } - | exception Not_found -> - (* If there is no existing domain for [nr], there were no - constraints applying to [nr] prior to the substitution. - - The only constraints that need to be notified are those that - were applying to [r], and they only need to be notified if the - new domain is different from the old domain of [r]. *) - let default = create_domain nrr in - let nd = Domain.intersect d default in - let rr_changed = not (Domain.equal nd d) in - (* Make sure to not add more constraints than necessary for the - representative domain. *) - let nd = if Domain.equal nd default then default else nd in - let domains = MX.add nrr nd t.domains in - let leaves_map = r_add nrr t.leaves_map in - let changed = changed || rr_changed in - let changed = - if changed then SX.add nrr t.changed else t.changed - in - { domains; changed; leaves_map } + let find x t = MX.find x t.domains + + let remove x t = + { domains = MX.remove x t.domains + ; changed = SX.remove x t.changed } - let has_changed t = - not @@ SX.is_empty t.changed + let add x d t = { t with domains = MX.add x d t.domains } + + let needs_propagation t = not (SX.is_empty t.changed) module Ephemeral = struct - type handle = - { repr : X.r - ; mutable domain : Domain.t - ; mutable dirty : bool - ; dirty_cache : handle HX.t - ; mutable changed : bool - ; changed_set : handle HX.t - } + type nonrec key = key + type nonrec domain = domain + + module Entry = struct + type t = + { entry : domain EX.Entry.t + ; key : key + ; notify : X.t -> unit } + + let domain { entry ; _ } = EX.Entry.content entry + + let set_domain { entry ; notify ; key } dom = + EX.Entry.set_content entry @@ dom; + notify key + end + + type t = + { domains : domain EX.t + ; notify : X.t -> unit } + + let entry t x = + { Entry.entry = EX.entry t.domains x + ; key = x + ; notify = t.notify } + end - let (!!) handle = handle.domain + let edit ~notify ~default t = + SX.iter notify t .changed; - let set_dirty handle = - if not handle.dirty then ( - handle.dirty <- true; - HX.replace handle.dirty_cache handle.repr handle - ) + { Ephemeral.domains = EX.edit ~default t.domains + ; notify } - let set_changed handle = - if not handle.changed then ( - set_dirty handle; - handle.changed <- true; - HX.replace handle.changed_set handle.repr handle - ) + let snapshot t = + { domains = EX.snapshot t.Ephemeral.domains + ; changed = SX.empty } +end - let update ~ex handle domain = - let domain = Domain.add_explanation ~ex domain in - let domain = Domain.intersect handle.domain domain in - if not (Domain.equal domain handle.domain) then ( - set_changed handle; - handle.domain <- domain - ) - type nonrec t = - { persistent : t - ; handles : handle HX.t - ; dirty_cache : handle HX.t - ; changed_set : handle HX.t } - - let handle t r = - try HX.find t.handles r with Not_found -> - let handle = - { repr = r - ; domain = get r t.persistent - ; dirty = false - ; dirty_cache = t.dirty_cache - ; changed = false - ; changed_set = t.changed_set } - in - HX.add t.handles r handle; - handle - - let structural_propagation t r = - (* Structural propagation is always correct and does not require - explanations because it follows the structure of the semantic value - itself. *) - let get r = !!(handle t r) in - let update r d = update ~ex:Explanation.empty (handle t r) d in - if X.is_a_leaf r then - match MX.find r t.persistent.leaves_map with - | parents -> - SX.iter (fun parent -> - if X.is_a_leaf parent then - assert (X.equal r parent) - else - update parent (Domain.map_leaves get parent) - ) parents - | exception Not_found -> () - else - Domain.fold_leaves (fun r d () -> update r d) r (get r) () +module BinRel(X : OrderedType)(W : OrderedType) : sig + (** This module provides a thin abstraction to keep track of binary relations + between values of two different types. *) - let iter_changed f t = HX.iter (fun r _ -> f r) t.changed_set + type t + (** The type of binary relations between [X.t] and [W.t]. *) - let clear_changed t = - HX.iter (fun _ h -> h.changed <- false) t.changed_set; - HX.clear t.changed_set - end + val empty : t + (** The empty relation. *) - let edit t = - let size = 17 in - let ephemeral = - { Ephemeral.persistent = { t with changed = SX.empty } - ; handles = HX.create size - ; dirty_cache = HX.create size - ; changed_set = HX.create size } + val add : X.t -> W.t -> t -> t + (** [add x w r] adds the tuple [(x, w)] to the relation. *) + + val add_many : X.t -> W.Set.t -> t -> t + + val range : X.t -> t -> W.Set.t + + val remove_dom : X.t -> t -> t + (** [remove_dom x r] removes all tuples of the form [(x, _)] from the + relation. *) + + val remove_range : W.t -> t -> t + (** [remove_range w r] removes all tuples of the form [(_, w)] from the + relation. *) + + val transfer_dom : X.t -> X.t -> t -> t + (** [transfer_dom x x' r] replaces all tuples of the form [(x, w)] in the + relation with the corresponding [(x', w)] tuple. *) + + val iter_range : X.t -> (W.t -> unit) -> t -> unit + (** [iter_range x f r] calls [f] on all the [w] such that [(x, w)] is in the + relation. *) + + val fold_range : X.t -> (W.t -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold_range x f r acc] folds [f] over all the [w] such that [(x, w)] is in + the relation.*) +end = struct + module MX = X.Map + module MW = W.Map + module SX = X.Set + module SW = W.Set + + type t = + { watches : SW.t MX.t ; + (** Reverse map from variables to their watches. Used to trigger watches + when a domain changes. *) + + watching : SX.t MW.t + (** Map from watches to the variables they watch. Used to be able to + remove watches. *) + } + + let range x t = + try MX.find x t.watches with Not_found -> W.Set.empty + + let empty = + { watches = MX.empty + ; watching = MW.empty } + + let add x w t = + let watches = + MX.update x (function + | None -> Some (SW.singleton w) + | Some ws -> Some (SW.add w ws)) t.watches + and watching = + MW.update w (function + | None -> Some (SX.singleton x) + | Some xs -> Some (SX.add x xs)) t.watching + in + { watches ; watching } + + let add_many x ws t = + let watches = + MX.update x (function + | None -> Some ws + | Some ws' -> Some (SW.union ws ws')) t.watches + and watching = + SW.fold (fun w watching -> + MW.update w (function + | None -> Some (SX.singleton x) + | Some xs -> Some (SX.add x xs)) watching + ) ws t.watching in - SX.iter (fun r -> - Ephemeral.set_changed (Ephemeral.handle ephemeral r) - ) t.changed; - ephemeral + { watches ; watching } + + let remove_range w t = + match MW.find w t.watching with + | xs -> + let watches = + SX.fold (fun x watches -> + MX.update x (function + | None -> + (* maps must be mutual inverses *) + assert false + | Some ws -> + let ws = SW.remove w ws in + if SW.is_empty ws then None else Some ws + ) watches + ) xs t.watches + in + let watching = MW.remove w t.watching in + { watches ; watching } + | exception Not_found -> t - let snapshot t = - assert (SX.is_empty t.Ephemeral.persistent.changed); - HX.fold (fun repr handle domains -> - unsafe_update - ~changed:handle.Ephemeral.changed repr handle.domain domains - ) t.Ephemeral.dirty_cache t.persistent + let remove_dom x t = + match MX.find x t.watches with + | ws -> + let watching = + SW.fold (fun w watching -> + MW.update w (function + | None -> + (* maps must be mutual inverses *) + assert false + | Some xs -> + let xs = SX.remove x xs in + if SX.is_empty xs then None else Some xs + ) watching + ) ws t.watching + and watches = MX.remove x t.watches in + { watches ; watching } + | exception Not_found -> t + + let fold_range x f t acc = + match MX.find x t.watches with + | ws -> SW.fold f ws acc + | exception Not_found -> acc + + let iter_range x f t = + match MX.find x t.watches with + | ws -> SW.iter f ws + | exception Not_found -> () + + let transfer_dom x x' t = + match MX.find x t.watches with + | ws -> + let watching = + SW.fold (fun w watching -> + MW.update w (function + | None -> + (* maps must be mutual inverses *) + assert false + | Some xs -> + Some (SX.add x' (SX.remove x xs)) + ) watching + ) ws t.watching + and watches = + MX.update x' (function + | None -> Some ws + | Some ws' -> Some (SW.union ws ws') + ) (MX.remove x t.watches) + in + { watches ; watching } + | exception Not_found -> t end -(** The ['c acts] type is used to register new facts and constraints in - [Constraint.simplify]. *) -type 'c acts = - { acts_add_lit_view : X.r L.view -> unit - (** Assert a semantic literal. *) - ; acts_add_eq : X.r -> X.r -> unit - (** Assert equality between two semantic values. *) - ; acts_add_constraint : 'c -> unit - (** Assert a new constraint. *) - } +(** Implementation of the [ComparableType] interface for semantic values. *) +module XComparable : ComparableType with type t = X.r = struct + type t = X.r -module type Constraint = sig - type t - (** The type of constraints. + let pp = X.print - Constraints apply to semantic values of type [X.r] as arguments. *) + let equal = X.equal - val pp : t Fmt.t - (** Pretty-printer for constraints. *) + let hash = X.hash - val compare : t -> t -> int - (** Comparison function for constraints. The comparison function is - arbitrary and has no semantic meaning. You should not depend on any of - its properties, other than it defines an (arbitrary) total order on - constraint representations. *) + let compare = X.hash_cmp - val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold_args f c acc] folds function [f] over the arguments of constraint - [c]. + module Set = SX - During propagation, the constraint {b MUST} only look at (and update) - the domains associated of its arguments; it is not allowed to look at - the domains of other semantic values. This allows efficient updates of - the pending constraints. *) + module Map = MX - val subst : X.r -> X.r -> t -> t - (** [subst p v cs] replaces all the instances of [p] with [v] in the - constraint. + module Table = HX +end - Substitution can perform constraint simplification. *) +module type NormalForm = sig + (** Module signature for normal form computation. *) - val simplify : t -> t acts -> bool - (** [simplify c acts] simplifies the constraint [c] by calling appropriate - functions on [acts]. + type constant + (** The type of constant values. *) - {b Note}: All the facts and constraints added through [acts] must be - logically implied by [c] {b only}. Doing otherwise is a {b soundness bug}. + type atom + (** The type of atomic variables that cannot be decomposed further. *) - Returns [true] if the constraint has been fully simplified and can - be removed, and [false] otherwise. + type composite + (** The type of composite variables that are obtained through a combination of + atomic variables (e.g. a multi-variate polynomial). *) - {b Note}: Returning [true] will cause the constraint to be removed, even - if it was re-added with [acts_add_constraint]. If you want to add new - facts/constraints but keep the existing constraint (usually a bad idea), - return [false] instead. *) -end + type t = + | Constant of constant + (** A constant value. *) + | Atom of atom * constant + (** An atomic variable with a constant offset. *) + | Composite of composite * constant + (** A composite variable with a constant offset. *) + (** The type of normal forms. *) -type 'a explained = { value : 'a ; explanation : Explanation.t } + type expr + (** The underlying type of non-normalized expressions. *) -let explained ~ex value = { value ; explanation = ex } + val normal_form : expr -> t + (** [normal_form e] computes the normal form of expression [e]. *) +end -module Constraints_make(Constraint : Constraint) : sig - type t - (** The type of constraint sets. A constraint set records a set of - constraints that applies to semantic values, and remembers the relation - between constraints and semantic values. +module type CompositeType = sig + (** Extension of the [ComparableType] signature for composite types, i.e. + types that are built up from a collection of smaller components. *) - The constraints applying to a given semantic value can be recovered using - the [iter_pending] functions. + include ComparableType - New constraints are marked as "pending" when added to the constraint set - (whether by a call to [add] or following a substitution). These - constraints should ultimately be propagated; they can be accessed through - the [iter_pending]. Once pending constraints have been propagated, the - "pending" constraints should be cleared with [clear_pending]. *) + type atom + (** The type of atoms that build up a composite value. *) - val pp : t Fmt.t - (** Pretty-printer for constraint sets. *) + val fold : (atom -> 'a -> 'a) -> t -> 'a -> 'a + (** [fold f c acc] folds [f] over all the atoms that make up [c]. *) +end - val empty : t - (** The empty constraint set. *) - - val add : ex:Explanation.t -> Constraint.t -> t -> t - (** [add ~ex c t] adds the constraint [c] to the set [t]. - - The explanation [ex] justifies that the constraint [c] holds. If the same - constraint is added multiple times with different explanations, only one - of the explanations for the constraint will be kept. *) - - val subst : ex:Explanation.t -> X.r -> X.r -> t -> t - (** [subst ~ex p v t] replaces all instances of [p] with [v] in the - constraints. - - The explanation [ex] justifies the equality [p = v]. *) - - val iter_parents : (Constraint.t explained -> unit) -> X.r -> t -> unit - (** [iter_parents f r t] calls [f] on all the constraints that apply directly - to [r] (precisely, all the constraints [r] is an argument of). *) - - val iter_pending : (Constraint.t explained -> unit) -> t -> unit - (** [iter_pending f t] calls [f] on all the constraints currently marked as - pending. Constraints are marked as pending when they are added, including - when a new constraint is added due to substitution of an old constraint - (whether the old constraint was pending or not). *) - - val clear_pending : t -> t - (** [clear_pending t] returns a copy of [t] except that no constraints are - marked as pending. *) - - val has_pending : t -> bool - (** [has_pending t] returns [true] if there is any constraint marked as - pending. Hence if [has_pending t] returns [false], [iter_pending] and - [clear_pending] are guaranteed to be no-ops. Should only be used for - optimization. *) - - val fold_args : (X.r -> 'a -> 'a) -> t -> 'a -> 'a - (** [fold_args f t acc] folds [f] over all the term representatives that are - arguments of at least one constraint. *) - - val simplify_pending : - (X.r L.view * Explanation.t) list -> t -> - (X.r L.view * Explanation.t) list * t - (** Simplify the pending constraints. This takes as argument a list of - (explained) literals, and returns a list of (explained) literals, so - that constraint simplification is able to propagate new literals - (typically equalities) to the UF module. *) -end = struct - module CS = Set.Make(struct - type t = Constraint.t explained +module type CompositeDomain = sig + (** Module signature to build a domain for a composite type from the domain of + its component atoms. *) - let compare a b = Constraint.compare a.value b.value - end) + type var + (** The type of (composite) variables. *) - type t = { - args_map : CS.t MX.t ; - (** Mapping from semantic values to constraints involving them *) + type atom + (** The type of atomic variables. *) - leaves_map : CS.t MX.t ; - (** Mapping from semantic values to constraints they are a leaf of *) + type domain + (** The type of domains we are building. *) - active : CS.t ; - (** Set of all currently active constraints, i.e. constraints that must - hold in a model and will be propagated. *) + val map_domain : (atom -> domain) -> var -> domain + (** [map_domain f c] constructs a domain for [c] from a function [f] that + returns the domain of an atom. *) +end - pending : CS.t ; - (** Set of active constraints that have not yet been propagated *) +type ('a, 'c, 'w) events = + { evt_atomic_change : 'a -> unit + ; evt_composite_change : 'c -> unit + (** Called by the ephemeral interface when the domain associated with a + variable changes. *) + ; evt_watch_trigger : 'w -> unit + (** Called by the ephemeral interface when a watcher is triggered. *) } +(** Handlers for events used by the ephemeral interface. *) + +module type VariableType = sig + (** Extension of the [ComparableType] signature for variables that have an + associated type. *) + + include ComparableType + + val type_info : t -> Ty.t + (** [type_info x] returns the type of variable [x]. *) +end + +module Domains_make + (D : OffsetDomain) + (A : VariableType) + (C : CompositeType with type atom = A.t) + (CD : CompositeDomain + with type var = C.t + and type atom = A.t + and type domain = D.t) + (NF : NormalForm + with type atom = A.t + and type composite = C.t + and type constant = D.constant + and type expr = X.r) + (W : OrderedType) + : sig + include Uf.GlobalDomain + + val get : X.r -> t -> D.t + (** [get r t] returns the domain associated with semantic value [r]. *) + + val watch : W.t -> X.r -> t -> t + (** [watch w r t] associated the watch [w] with the domain of semantic value + [r]. The watch [w] is triggered whenever the domain associated with [r] + changes, and is preserved across substitutions (i.e. if [r] becomes + [nr], [w] will be transfered to [nr]). + + {b Note}: The watch [w] is also immediately triggered for a first + propagation. *) + + val unwatch : W.t -> t -> t + (** [unwatch w] removes [w] from all watch lists. It will no longer be + triggered. + + {b Note}: If [w] has already been triggered, it is not removed from the + triggered list. *) + + val needs_propagation : t -> bool + (** Returns [true] if the domains needs propagation, i.e. if any variable's + domain has changed. *) + + val variables : t -> A.Set.t + (** Returns the set of atomic variables that are currently being tracked. *) + + val parents : t -> C.Set.t A.Map.t + (** Returns a map from atomic variables to all the composite variables that + contain them and are currently being tracked. *) + + module Ephemeral : EphemeralDomainMap + with type key = X.r + and type domain = D.t - let pp ppf { active; _ } = - Fmt.( - braces @@ hvbox @@ - iter ~sep:semi CS.iter @@ - using (fun { value; _ } -> value) @@ - box ~indent:2 @@ braces @@ - Constraint.pp - ) ppf active + val edit : events:(A.t, C.t, W.t) events -> t -> Ephemeral.t + (** [edit ~events t] returns an ephemeral copy of the domains for edition. + + The [events] argument is used to notify the caller about domain changes + and watches being triggered. + + {b Note}: Any domain that has changed or watches that have been + triggered through the persistent API (e.g. due to substitutions) are + immediately notified through the appropriare [events] callback. *) + + val snapshot : Ephemeral.t -> t + (** Converts back an ephemeral domain into a persistent one. *) + end += +struct + module DMA = DomainMap(A)(D) + module DMC = DomainMap(C)(D) + + module AW = BinRel(A)(W) + module CW = BinRel(C)(W) + + type t = + { atoms : DMA.t + (* Map from atomic variables to their (non-default) domain. *) + ; atom_watches : AW.t + (* Map (and reverse map) from atomic variables to the watches that must be + triggered when their domain gets updated. *) + ; variables : A.Set.t + (* Set of all atomic variables being tracked. *) + ; composites : DMC.t + (* Map from composite variables to their (non-default) domain. *) + ; composite_watches : CW.t + (* Map (and reverse map) from composite variables to the watches that must + be triggered when their domain gets udpated. *) + ; parents : C.Set.t A.Map.t + (* Reverse map from atomic variables to the composite variables that + contain them. Useful for structural propagation. *) + ; triggers : W.Set.t + (* Watches that have been triggered. They will be immediately notified + when [edit] is called. *) + } + + let pp ppf { atoms ; composites ; _ } = + DMA.pp ppf atoms; + DMC.pp ppf composites let empty = - { args_map = MX.empty - ; leaves_map = MX.empty - ; active = CS.empty - ; pending = CS.empty } - - let cs_add c r cs_map = - MX.update r (function - | Some cs -> Some (CS.add c cs) - | None -> Some (CS.singleton c) - ) cs_map - - let fold_leaves f c acc = - Constraint.fold_args (fun r acc -> - List.fold_left (fun acc r -> f r acc) acc (X.leaves r) - ) c acc - - let add ~ex c t = - let c = explained ~ex c in - (* Note: use [CS.find] here, not [CS.mem], to ensure we use the same - explanation for [c] in the [pending] and [active] sets. *) - if CS.mem c t.active then t else - let active = CS.add c t.active in - let args_map = - Constraint.fold_args (cs_add c) c.value t.args_map - in - let leaves_map = fold_leaves (cs_add c) c.value t.leaves_map in - let pending = CS.add c t.pending in - { active; args_map; leaves_map; pending } - - let cs_remove c r cs_map = - MX.update r (function - | Some cs -> - let cs = CS.remove c cs in - if CS.is_empty cs then None else Some cs - | None -> None - ) cs_map - - let remove c t = - let active = CS.remove c t.active in - let args_map = - Constraint.fold_args (cs_remove c) c.value t.args_map - in - let leaves_map = fold_leaves (cs_remove c) c.value t.leaves_map in - let pending = CS.remove c t.pending in - { active; args_map; leaves_map; pending } + { atoms = DMA.empty + ; atom_watches = AW.empty + ; variables = A.Set.empty + ; composites = DMC.empty + ; composite_watches = CW.empty + ; parents = A.Map.empty + ; triggers = W.Set.empty + } + + type _ Uf.id += Id : t Uf.id + + let filter_ty = D.filter_ty + + exception Inconsistent = D.Inconsistent + + let watch w r t = + let t = { t with triggers = W.Set.add w t.triggers } in + match NF.normal_form r with + | Constant _ -> t + | Atom (a, _) -> + { t with atom_watches = AW.add a w t.atom_watches } + | Composite (c, _) -> + { t with composite_watches = CW.add c w t.composite_watches } + + let unwatch w t = + { atoms = t.atoms + ; atom_watches = AW.remove_range w t.atom_watches + ; variables = t.variables + ; composites = t.composites + ; composite_watches = CW.remove_range w t.composite_watches + ; parents = t.parents + ; triggers = t.triggers } + + let needs_propagation t = + DMA.needs_propagation t.atoms || + DMC.needs_propagation t.composites || + not (W.Set.is_empty t.triggers) + + let variables { variables ; _ } = variables + + let parents { parents ; _ } = parents + + let track c parents = + C.fold (fun a t -> + A.Map.update a (function + | Some cs -> Some (C.Set.add c cs) + | None -> Some (C.Set.singleton c) + ) t + ) c parents + + let untrack c parents = + C.fold (fun a t -> + A.Map.update a (function + | Some cs -> + let cs = C.Set.remove c cs in + if C.Set.is_empty cs then None else Some cs + | None -> None + ) t + ) c parents + + let init r t = + match NF.normal_form r with + | Constant _ -> t + | Atom (a, _) -> + { t with variables = A.Set.add a t.variables } + | Composite (c, _) -> + { t with parents = track c t.parents } + + let default_atom a = D.unknown (A.type_info a) + + let find_or_default_atom a t = + try DMA.find a t.atoms + with Not_found -> default_atom a + + let default_composite c = CD.map_domain default_atom c + + let find_or_default_composite c t = + try DMC.find c t.composites + with Not_found -> default_composite c + + let find_or_default x t = + match x with + | NF.Constant c -> + D.constant c + | NF.Atom (a, o) -> + D.add_offset (find_or_default_atom a t) o + | NF.Composite (c, o) -> + D.add_offset (find_or_default_composite c t) o + + let get r t = find_or_default (NF.normal_form r) t let subst ~ex rr nrr t = - match MX.find rr t.leaves_map with - | cs -> - CS.fold (fun c t -> - let t = remove c t in - let ex = Explanation.union ex c.explanation in - add ~ex (Constraint.subst rr nrr c.value) t - ) cs t - | exception Not_found -> t + let rrd, ws, t = + match NF.normal_form rr with + | Constant _ -> invalid_arg "subst: cannot substitute a constant" + | Atom (a, o) -> + let variables = A.Set.remove a t.variables in + D.add_offset (find_or_default_atom a t) o, + AW.range a t.atom_watches, + { t with + atoms = DMA.remove a t.atoms ; + atom_watches = AW.remove_dom a t.atom_watches ; + variables } + | Composite (c, o) -> + let parents = untrack c t.parents in + D.add_offset (find_or_default_composite c t) o, + CW.range c t.composite_watches, + { t with + composites = DMC.remove c t.composites ; + composite_watches = CW.remove_dom c t.composite_watches ; + parents } + in + (* Add [ex] to justify that it applies to [nrr] *) + let rrd = D.add_explanation ~ex rrd in + let nrr_nf = NF.normal_form nrr in + let nrrd = find_or_default nrr_nf t in + let nnrrd = D.intersect nrrd rrd in + let t = + if D.equal nnrrd rrd then t + else { t with triggers = W.Set.union ws t.triggers } + in + let t = + match nrr_nf with + | Constant _ -> t + | Atom (a, _) -> + let atom_watches = AW.add_many a ws t.atom_watches in + let variables = A.Set.add a t.variables in + { t with atom_watches ; variables } + | Composite (c, _) -> + let composite_watches = CW.add_many c ws t.composite_watches in + let parents = track c t.parents in + { t with composite_watches ; parents } + in + if D.equal nnrrd nrrd then t + else + match nrr_nf with + | Constant _ -> + (* [nrrd] is [D.constant c] which must be a singleton; if we + shrunk it, it can only be empty. *) + assert false + | Atom (a, o) -> + let triggers = W.Set.union (AW.range a t.atom_watches) t.triggers in + let atoms = DMA.add a (D.sub_offset nnrrd o) t.atoms in + { t with atoms ; triggers } + | Composite (c, o) -> + let triggers = + W.Set.union (CW.range c t.composite_watches) t.triggers + in + let composites = DMC.add c (D.sub_offset nnrrd o) t.composites in + { t with composites ; triggers } - let iter_parents f r t = - match MX.find r t.args_map with - | cs -> CS.iter f cs - | exception Not_found -> () + module Ephemeral = struct + type key = X.r + type domain = D.t + + module Entry = struct + type t = + | Constant of NF.constant + | Atom of DMA.Ephemeral.Entry.t * NF.constant + | Composite of DMC.Ephemeral.Entry.t * NF.constant + + let domain = function + | Constant c -> D.constant c + | Atom (a, o) -> + D.add_offset (DMA.Ephemeral.Entry.domain a) o + | Composite (c, o) -> + D.add_offset (DMC.Ephemeral.Entry.domain c) o + + let set_domain e d = + match e with + | Constant _ -> assert false + | Atom (a, o) -> + DMA.Ephemeral.Entry.set_domain a (D.sub_offset d o) + | Composite (c, o) -> + DMC.Ephemeral.Entry.set_domain c (D.sub_offset d o) + end + + type t = + { atoms : DMA.Ephemeral.t + ; atom_watches : AW.t + ; variables : A.Set.t + ; composites : DMC.Ephemeral.t + ; composite_watches : CW.t + ; parents : C.Set.t A.Map.t } + + let entry t r = + match NF.normal_form r with + | NF.Constant c -> + Entry.Constant c + | NF.Atom (a, o) -> + Atom (DMA.Ephemeral.entry t.atoms a, o) + | NF.Composite (c, o) -> + Entry.Composite (DMC.Ephemeral.entry t.composites c, o) + end - let iter_pending f t = - CS.iter f t.pending - - let clear_pending t = - { t with pending = CS.empty } - - let has_pending t = not @@ CS.is_empty t.pending - - let fold_args f c acc = - MX.fold (fun r _ acc -> - f r acc - ) c.args_map acc - - let simplify_pending = - (* Recursion needed because adding new constraints changes the pending set - and they also need to be simplified *) - let rec simplify_aux eqs t to_simplify = - let eqs = ref eqs in - let to_add = ref CS.empty in - let t = - CS.fold (fun ({ value; explanation } as c) t -> - let acts_add_lit_view l = - eqs := (l, explanation) :: !eqs - in - let acts_add_eq u v = - acts_add_lit_view (Uf.LX.mkv_eq u v) - in - let acts_add_constraint c = - let c = { value = c; explanation } in - if not (CS.mem c t.active) then - to_add := CS.add c !to_add - in - let acts = - { acts_add_lit_view - ; acts_add_eq - ; acts_add_constraint } in - if Constraint.simplify value acts then - remove c t - else - t - ) to_simplify t - in - let to_add = !to_add in - if CS.is_empty to_add then - !eqs, t - else - let t = CS.fold (fun c t -> add ~ex:c.explanation c.value t) to_add t in - simplify_aux !eqs t to_add + let edit ~events t = + W.Set.iter events.evt_watch_trigger t.triggers; + + let notify_atom a = + events.evt_atomic_change a; + AW.iter_range a events.evt_watch_trigger t.atom_watches + and notify_composite c = + events.evt_composite_change c; + CW.iter_range c events.evt_watch_trigger t.composite_watches in - fun eqs t -> - if CS.is_empty t.pending then eqs, t else - simplify_aux eqs t t.pending + + { Ephemeral.atoms = + DMA.edit + ~notify:notify_atom ~default:default_atom + t.atoms + ; atom_watches = t.atom_watches + ; variables = t.variables + ; composites = + DMC.edit + ~notify:notify_composite ~default:default_composite + t.composites + ; composite_watches = t.composite_watches + ; parents = t.parents } + + let snapshot t = + { atoms = DMA.snapshot t.Ephemeral.atoms + ; atom_watches = t.Ephemeral.atom_watches + ; variables = t.Ephemeral.variables + ; composites = DMC.snapshot t.Ephemeral.composites + ; composite_watches = t.Ephemeral.composite_watches + ; parents = t.Ephemeral.parents + ; triggers = W.Set.empty } +end + +(** Wrapper around an ephemeral domain map to access domains associated with a + representative computed by the [Uf] module. *) +module UfHandle + (D : Domain) + (DM : EphemeralDomainMap with type key = X.r and type domain = D.t) + : sig + include EphemeralDomainMap with type key = X.r and type domain = D.t + + val wrap : Uf.t -> DM.t -> t + end += +struct + type key = X.r + + type domain = DM.domain + + module Entry = struct + type t = + { repr : X.r + ; handle : DM.Entry.t + ; explanation : Explanation.t } + + let domain { repr ; handle ; explanation = ex } = + if Explanation.is_empty ex then DM.Entry.domain handle + else + D.intersect (D.unknown (X.type_info repr)) @@ + D.add_explanation ~ex (DM.Entry.domain handle) + + let set_domain { handle ; explanation = ex ; _ } d = + DM.Entry.set_domain handle (D.add_explanation ~ex d) + end + + type t = + { uf : Uf.t + ; cache : Entry.t HX.t + ; domains : DM.t } + + let entry t r = + try HX.find t.cache r with Not_found -> + let r, explanation = Uf.find_r t.uf r in + let h = + { Entry.repr = r + ; handle = DM.entry t.domains r + ; explanation } + in + HX.replace t.cache r h; h + + let wrap uf t = + { uf ; cache = HX.create 17 ; domains = t } +end + +module HandleNotations + (D : Domain) + (E : EphemeralDomainMap with type domain = D.t) = +struct + let (!!) = E.Entry.domain + + let update ~ex entry domain = + let current = E.Entry.domain entry in + let domain = D.intersect current (D.add_explanation ~ex domain) in + if not (D.equal domain current) then + E.Entry.set_domain entry domain end diff --git a/src/lib/reasoners/uf.ml b/src/lib/reasoners/uf.ml index d1a4d2e30a..5c26bc7e53 100644 --- a/src/lib/reasoners/uf.ml +++ b/src/lib/reasoners/uf.ml @@ -88,7 +88,7 @@ module type GlobalDomain = sig val filter_ty : Ty.t -> bool - val add : X.r -> t -> t + val init : X.r -> t -> t exception Inconsistent of Explanation.t @@ -135,7 +135,7 @@ module GlobalDomains = struct let init r t = let ty = X.type_info r in MapI.map (function B ((module D) as dom, d) as b -> - if D.filter_ty ty then B (dom, D.add r d) else b + if D.filter_ty ty then B (dom, D.init r d) else b ) t let add (type a) ((module D) as dom : a global_domain) v t = diff --git a/src/lib/reasoners/uf.mli b/src/lib/reasoners/uf.mli index 8ffd416fb8..1aca26260c 100644 --- a/src/lib/reasoners/uf.mli +++ b/src/lib/reasoners/uf.mli @@ -68,8 +68,8 @@ module type GlobalDomain = sig of representatives for which [filter_ty (type_info r)] holds will be propagated to this module. *) - val add : r -> t -> t - (** [add r t] is called when the representative [r] is added to the + val init : r -> t -> t + (** [init r t] is called when the representative [r] is added to the union-find, if it has a type that matches [filter_ty]. {b Note}: unlike [Relation.add], this function is called even for