From 5c14a39928549d0dc65ae55e2f82985e95d5f573 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Basile=20Cl=C3=A9ment?= <129742207+bclement-ocp@users.noreply.github.com> Date: Fri, 21 Jun 2024 10:02:57 +0200 Subject: [PATCH] feat(BV): Interval domains for bit-vectors (#1058) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(BV): Interval domains for bit-vectors This patch adds interval domains to the Bitv_rel module, as well as interreductions between the bitlist and interval domains following: Sharpening Constraint Programming approaches for Bit-Vector Theory. Zakaria Chihani, Bruno Marre, François Bobot, Sébastien Bardin. CPAIOR 2017. International Conference on AI and OR Techniques in Constraint Programming for Combinatorial Optimization Problems, Jun 2017, Padova, Italy. More precisely: - The `Intervals` module is extended to support the `extract` operation, which is used to propagate between bit-vector compositions and their components; - The interreductions are implemented using the new `Bitlist.increase_lower_bound`, `Bitlist.decrease_upper_bound`, and the new `shared_msb` helper in `Bitv_rel`; - Propagations are performed by alternating propagations until fixpoint in each domain, followed by interreductions and propagations until fixpoint in the other domain, until reaching a fixpoint for the whole procedure. It is not clear that this is the best strategy; the goal is to try and limit interreductions since they are relatively expensive but we should revisit this once more operations are supported. For now, only the `bvule`, `bvult`, `bvugt` and `bvuge` primitives are supported as built-in bit-vector operations; other operations such as `bvadd` are still encoded using `bv2nat`. These operations will be migrated to bit-vector primitives in a follow-up PR. Finally, there are some tests for the tricky bits (`Intervals.extract` and the interreduction primitives) using QCheck. --- alt-ergo-lib.opam | 1 + alt-ergo-lib.opam.locked | 1 + dune-project | 1 + shell.nix | 1 + src/lib/frontend/models.ml | 15 +- src/lib/reasoners/bitlist.ml | 102 +++++ src/lib/reasoners/bitlist.mli | 20 + src/lib/reasoners/bitv_rel.ml | 570 +++++++++++++++++++++--- src/lib/reasoners/bitv_rel.mli | 4 + src/lib/reasoners/intervals.ml | 64 ++- src/lib/reasoners/intervals.mli | 37 +- src/lib/reasoners/intervals_intf.ml | 6 + src/lib/reasoners/rel_utils.ml | 10 +- src/lib/structures/expr.ml | 22 +- src/lib/structures/symbols.ml | 5 +- src/lib/structures/symbols.mli | 1 + src/lib/structures/xliteral.ml | 10 + src/lib/structures/xliteral.mli | 1 + tests/bitv/bitlist-interval001.expected | 2 + tests/bitv/bitlist-interval001.smt2 | 5 + tests/bitvec_tests.ml | 231 ++++++++++ tests/dune | 6 + tests/dune.inc | 251 ++++++++++- 23 files changed, 1280 insertions(+), 86 deletions(-) create mode 100644 tests/bitv/bitlist-interval001.expected create mode 100644 tests/bitv/bitlist-interval001.smt2 create mode 100644 tests/bitvec_tests.ml diff --git a/alt-ergo-lib.opam b/alt-ergo-lib.opam index a2d48a46d0..db14b35520 100644 --- a/alt-ergo-lib.opam +++ b/alt-ergo-lib.opam @@ -31,6 +31,7 @@ depends: [ "odoc" {with-doc} "ppx_deriving" "stdcompat" + "qcheck" {with-test} ] conflicts: [ "ppxlib" {< "0.30.0"} diff --git a/alt-ergo-lib.opam.locked b/alt-ergo-lib.opam.locked index 0def9a7a84..a9a2593cd0 100644 --- a/alt-ergo-lib.opam.locked +++ b/alt-ergo-lib.opam.locked @@ -53,6 +53,7 @@ depends: [ "ppx_derivers" {= "1.2.1"} "ppx_deriving" {= "5.2.1"} "ppxlib" {= "0.31.0"} + "qcheck" {= "0.21.3"} "result" {= "1.5"} "seq" {= "base"} "sexplib0" {= "v0.16.0"} diff --git a/dune-project b/dune-project index 080b1bae9e..8cac728ce8 100644 --- a/dune-project +++ b/dune-project @@ -91,6 +91,7 @@ See more details on http://alt-ergo.ocamlpro.com/" (odoc :with-doc) ppx_deriving stdcompat + (qcheck :with-test) ) (conflicts (ppxlib (< 0.30.0)) diff --git a/shell.nix b/shell.nix index 6311586340..b4438ffb35 100644 --- a/shell.nix +++ b/shell.nix @@ -43,5 +43,6 @@ pkgs.mkShell { stdcompat landmarks landmarks-ppx + qcheck ]; } diff --git a/src/lib/frontend/models.ml b/src/lib/frontend/models.ml index 2cf7cb2460..7b39d1bc6f 100644 --- a/src/lib/frontend/models.ml +++ b/src/lib/frontend/models.ml @@ -105,6 +105,18 @@ module Pp_smtlib_term = struct | Sy.L_neg_pred, [a] -> fprintf fmt "(not %a)" print a + | Sy.L_built Sy.BVULE, [a;b] -> + if Options.get_output_smtlib () then + fprintf fmt "(bvule %a %a)" print a print b + else + fprintf fmt "(%a <= %a)" print a print b + + | Sy.L_neg_built Sy.BVULE, [a;b] -> + if Options.get_output_smtlib () then + fprintf fmt "(bvugt %a %a)" print a print b + else + fprintf fmt "(%a > %a)" print a print b + | Sy.L_built (Sy.IsConstr hs), [e] -> if Options.get_output_smtlib () then fprintf fmt "((_ is %a) %a)" Uid.pp hs print e @@ -117,7 +129,8 @@ module Pp_smtlib_term = struct else fprintf fmt "not (%a ? %a)" print e Uid.pp hs - | (Sy.L_built (Sy.LT | Sy.LE) | Sy.L_neg_built (Sy.LT | Sy.LE) + | (Sy.L_built (Sy.LT | Sy.LE | Sy.BVULE) + | Sy.L_neg_built (Sy.LT | Sy.LE | Sy.BVULE) | Sy.L_neg_pred | Sy.L_eq | Sy.L_neg_eq | Sy.L_built (Sy.IsConstr _) | Sy.L_neg_built (Sy.IsConstr _)) , _ -> diff --git a/src/lib/reasoners/bitlist.ml b/src/lib/reasoners/bitlist.ml index c248110cad..c45535fc89 100644 --- a/src/lib/reasoners/bitlist.ml +++ b/src/lib/reasoners/bitlist.ml @@ -162,3 +162,105 @@ let logxor b1 b2 = ; bits_clr ; ex = Ex.union b1.ex b2.ex } + +(* The logic for the [increase_lower_bound] function below is described in + section 4.1 of + + Sharpening Constraint Programming approaches for Bit-Vector Theory. + Zakaria Chihani, Bruno Marre, François Bobot, Sébastien Bardin. + CPAIOR 2017. International Conference on AI and OR Techniques in + Constraint Programming for Combinatorial Optimization Problems, Jun + 2017, Padova, Italy. + https://cea.hal.science/cea-01795779/document *) + +(* [left_cl_can_set highest_cleared cleared_can_set] returns the + least-significant bit that is: + - More significant than [highest_cleared], strictly; + - Set in [cleared_can_set] *) +let left_cl_can_set highest_cleared cleared_can_set = + let can_set = Z.(cleared_can_set asr highest_cleared) in + highest_cleared + Z.trailing_zeros can_set + +let increase_lower_bound b lb = + (* [r] is the new candidate lower bound; we only keep the *unknown* bits of + [lb] and otherwise use the known bits from the domain [b]. + + [cleared_bits] contains the bits that were set in [lb] and got cleared in + [r]; conversely, [set_bits] contains the bits that were cleared in [lb] and + got set in [r]. *) + let r = Z.logor b.bits_set (Z.logand lb (Z.lognot b.bits_clr)) in + let cleared_bits = Z.logand lb (Z.lognot r) in + let set_bits = Z.logand (Z.lognot lb) r in + + (* We now look at the most-significant bit that was changed (since [set_bits] + and [cleared_bits] have disjoint bits set, comparing them is equivalent to + comparing their most significant bit). *) + let c = Z.compare set_bits cleared_bits in + if c > 0 then ( + (* [set_bits > cleared_bits] means that the most-significant changed bit + was 0, and is now 1. + + Any higher bits are unchanged, but all lower bits that are not forced + must be cleared (for instance we can only increase 0b010 to 0b100; + increasing it to 0b110 would be incorrect). + + The following clears any lower bits ([Z.numbits set_bits] is the + most-significant bit that was set), unless they are forced to 1. *) + let bit_to_set = Z.numbits set_bits in + let mask = Z.(minus_one lsl bit_to_set) in + Z.logand r @@ Z.logor mask b.bits_set + ) else if c = 0 then ( + (* [set_bits] and [cleared_bits] can only be equal if they are both zero, + because no bit can go from 0 to 1 *and* from 1 to 0 at the same time. *) + assert (Z.equal set_bits Z.zero); + assert (Z.equal r lb); + lb + ) else ( + (* [cleared_bits > set_bits] means that the most-significant changed bit was + 1, and is now 0. To achieve this while increasing the value, we need to + set a higher bit from 0 to 1, and it needs to be the *lowest* bit that is + higher than the most-significant changed bit. + + For instance to clear 0b01[1]011 we need to go to 0b100000. + + Once we found that bit (done by [left_cl_can_set]), we do the same thing + as when the most-significant changed bit was 0 and is now 1 (see [if] + case above). *) + let bit_to_clear = Z.numbits cleared_bits in + let cleared_can_set = Z.lognot @@ Z.logor r b.bits_clr in + let bit_to_set = left_cl_can_set bit_to_clear cleared_can_set in + if bit_to_set >= b.width then + raise Not_found; + let r = Z.logor r Z.(~$1 lsl bit_to_set) in + let mask = Z.(minus_one lsl bit_to_set) in + Z.logand r @@ Z.logor mask b.bits_set + ) + +let decrease_upper_bound b ub = + (* x <= ub <-> ~ub <= ~x *) + let sz = width b in + assert (Z.numbits ub <= sz); + let nub = + increase_lower_bound (lognot b) (Z.extract (Z.lognot ub) 0 sz) + in + Z.extract (Z.lognot nub) 0 sz + +let fold_domain f b acc = + if b.width <= 0 then + invalid_arg "Bitlist.fold_domain"; + let rec fold_domain_aux ofs b acc = + if ofs >= b.width then ( + assert (is_fully_known b); + f (value b) acc + ) else if Z.testbit b.bits_clr ofs || Z.testbit b.bits_set ofs then + fold_domain_aux (ofs + 1) b acc + else + let mask = Z.(one lsl ofs) in + let acc = + fold_domain_aux + (ofs + 1) { b with bits_clr = Z.logor b.bits_clr mask } acc + in + fold_domain_aux + (ofs + 1) { b with bits_set = Z.logor b.bits_set mask } acc + in + fold_domain_aux 0 b acc diff --git a/src/lib/reasoners/bitlist.mli b/src/lib/reasoners/bitlist.mli index fa7a935bd2..66c73aa703 100644 --- a/src/lib/reasoners/bitlist.mli +++ b/src/lib/reasoners/bitlist.mli @@ -128,3 +128,23 @@ val extract : t -> int -> int -> t (** [extract b i j] returns the bitlist from index [i] to index [j] inclusive. The resulting bitlist has length [j - i + 1]. *) + +val increase_lower_bound : t -> Z.t -> Z.t +(** [increase_lower_bound b lb] returns the smallest integer [lb' >= lb] that + matches the bit-pattern in [b]. + + @raise Not_found if no such integer exists. *) + +val decrease_upper_bound : t -> Z.t -> Z.t +(** [decrease_upper_bound b ub] returns the largest integer [ub' >= ub] that + matches the bit-pattern in [b]. + + @raise Not_found if no such integer exists. *) + +(**/**) + +(** [fold_finite_domain f i acc] accumulates [f] on all the elements of [i] (in + an unspecified order). Intended for testing purposes only. + + @raise Invalid_argument if the bitlist is [empty]. *) +val fold_domain : (Z.t -> 'a -> 'a) -> t -> 'a -> 'a diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index 8f85929b72..6777b492c1 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -75,7 +75,94 @@ let is_bv_ty = function let is_bv_r r = is_bv_ty @@ X.type_info r -module Domain : Rel_utils.Domain with type t = Bitlist.t = struct +module Interval_domain : Rel_utils.Domain with type t = Intervals.Int.t = struct + type t = Intervals.Int.t + + let equal = Intervals.Int.equal + + let pp = Intervals.Int.pp + + exception Inconsistent of Explanation.t + + let add_explanation = Intervals.Int.add_explanation + + let unknown = function + | Ty.Tbitv n -> + Intervals.Int.of_bounds + (Closed Z.zero) (Open Z.(~$1 lsl n)) + | ty -> + Fmt.invalid_arg "unknown: only bit-vector types are supported; got %a" + Ty.print ty + + let filter_ty = is_bv_ty + + let intersect x y = + match Intervals.Int.intersect x y with + | Empty ex -> + raise @@ Inconsistent ex + | NonEmpty u -> u + + let lognot sz int = + Intervals.Int.extract ~ofs:0 ~len:sz @@ Intervals.Int.lognot int + + let fold_signed f { Bitv.value; negated } sz int acc = + f value (if negated then lognot sz int else int) acc + + let point ?ex n = + Intervals.Int.of_bounds ?ex (Closed n) (Closed n) + + let fold_leaves f r int acc = + let width = + match X.type_info r with + | Tbitv n -> n + | _ -> assert false + in + let j, acc = + List.fold_left (fun (j, acc) { Bitv.bv; sz } -> + (* sz = j - i + 1 => i = j - sz + 1 *) + let int = Intervals.Int.extract int ~ofs:(j - sz + 1) ~len:sz in + + let acc = match bv with + | Bitv.Cte z -> + (* Nothing to update, but still check for consistency *) + ignore @@ intersect int (point z); + acc + | Other s -> fold_signed f s sz int acc + | Ext (r, r_size, i, j) -> + (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) + assert (i + r_size - j - 1 + sz = r_size); + let lo = unknown (Tbitv i) in + let int = Intervals.Int.scale Z.(~$1 lsl i) int in + let hi = unknown (Tbitv (r_size - j - 1)) in + let hi = + Intervals.Int.scale Z.(~$1 lsl Stdlib.(i + sz)) hi + in + fold_signed f r r_size Intervals.Int.(add hi (add int lo)) acc + in + (j - sz), acc + ) (width - 1, acc) (Shostak.Bitv.embed r) + in + assert (j = -1); + acc + + let map_signed f { Bitv.value; negated } sz = + if negated then lognot sz (f value) else f value + + let map_leaves f r = + List.fold_left (fun ival { Bitv.bv; sz } -> + let ival = Intervals.Int.scale Z.(~$1 lsl sz) ival in + Intervals.Int.add ival @@ + match bv with + | Bitv.Cte z -> point z + | Other s -> map_signed f s sz + | Ext (s, sz', i, j) -> + Intervals.Int.extract (map_signed f s sz') ~ofs:i ~len:(j - i + 1) + ) (point Z.zero) (Shostak.Bitv.embed r) +end + +module Interval_domains = Rel_utils.Domains_make(Interval_domain) + +module Bitlist_domain : Rel_utils.Domain with type t = Bitlist.t = struct (* Note: these functions are not in [Bitlist] proper in order to avoid a (direct) dependency from [Bitlist] to the [Shostak] module. *) @@ -111,17 +198,17 @@ module Domain : Rel_utils.Domain with type t = Bitlist.t = struct fold_signed f r (hi @ bl @ lo) acc, bl_tail ) (acc, bl) (Shostak.Bitv.embed r) - let map_signed f { Bitv.value; negated } t = - let bl = f value t in + let map_signed f { Bitv.value; negated } = + let bl = f value in if negated then lognot bl else bl - let map_leaves f r acc = + let map_leaves f r = List.fold_left (fun bl { Bitv.bv; sz } -> concat bl @@ match bv with | Bitv.Cte z -> exact sz z Ex.empty - | Other r -> map_signed f r acc - | Ext (r, _r_size, i, j) -> extract (map_signed f r acc) i j + | Other r -> map_signed f r + | Ext (r, _r_size, i, j) -> extract (map_signed f r) i j ) empty (Shostak.Bitv.embed r) let unknown = function @@ -131,7 +218,7 @@ module Domain : Rel_utils.Domain with type t = Bitlist.t = struct invalid_arg "unknown" end -module Domains = Rel_utils.Domains_make(Domain) +module Bitlist_domains = Rel_utils.Domains_make(Bitlist_domain) module Constraint : sig include Rel_utils.Constraint @@ -149,11 +236,17 @@ module Constraint : sig val bvxor : X.r -> X.r -> X.r -> t (** [bvxor x y z] is the constraint [x ^ y ^ z = 0] *) - val propagate : ex:Ex.t -> t -> Domains.Ephemeral.t -> unit + val bvule : X.r -> X.r -> t + + val bvugt : X.r -> X.r -> t + + val propagate_bitlist : ex:Ex.t -> t -> Bitlist_domains.Ephemeral.t -> unit (** [propagate ~ex t dom] propagates the constraint [t] in domain [dom]. The explanation [ex] justifies that the constraint [t] applies, and must be added to any domain that gets updated during propagation. *) + + val propagate_interval : ex:Ex.t -> t -> Interval_domains.Ephemeral.t -> unit end = struct type binop = (* Bitwise operations *) @@ -176,7 +269,7 @@ end = struct | Band | Bor | Bxor -> true let propagate_binop ~ex dx op dy dz = - let open Domains.Ephemeral in + let open Bitlist_domains.Ephemeral in match op with | Band -> update ~ex dx (Bitlist.logand !!dy !!dz); @@ -204,6 +297,10 @@ end = struct update ~ex dy (Bitlist.logxor !!dx !!dz); update ~ex dz (Bitlist.logxor !!dx !!dy) + let propagate_interval_binop ~ex:_ _r _op _y _z = + (* No interval propagation for binops yet *) + () + type fun_t = | Fbinop of binop * X.r * X.r @@ -232,39 +329,136 @@ end = struct | Fbinop (op, x, y) -> Fbinop (op, X.subst rr nrr x, X.subst rr nrr y) let propagate_fun_t ~ex dom r f = - let open Domains.Ephemeral in + let open Bitlist_domains.Ephemeral in let get r = handle dom r in match f with | Fbinop (op, x, y) -> propagate_binop ~ex (get r) op (get x) (get y) + let propagate_interval_fun_t ~ex dom r f = + let get r = Interval_domains.Ephemeral.handle dom r in + match f with + | Fbinop (op, x, y) -> + propagate_interval_binop ~ex (get r) op (get x) (get y) + + type binrel = Rule | Rugt + + let pp_binrel ppf = function + | Rule -> Fmt.pf ppf "bvule" + | Rugt -> Fmt.pf ppf "bvugt" + + let equal_binrel : binrel -> binrel -> bool = Stdlib.(=) + + let hash_binrel : binrel -> int = Hashtbl.hash + + let propagate_binrel ~ex:_ _op _dx _dy = + (* No bitlist propagation for relations yet *) + () + + let less_than_sup ~strict iv = + let sup, ex = Intervals.Int.upper_bound iv in + let sup = if strict then Intervals.map_bound Z.pred sup else sup in + Intervals.Int.of_bounds ~ex Unbounded sup + + let greater_than_inf ~strict iv = + let inf, ex = Intervals.Int.lower_bound iv in + let inf = if strict then Intervals.map_bound Z.succ inf else inf in + Intervals.Int.of_bounds ~ex inf Unbounded + + let propagate_less_than ~ex ~strict dx dy = + let open Interval_domains.Ephemeral in + update ~ex dx (less_than_sup ~strict !!dy); + update ~ex dy (greater_than_inf ~strict !!dx) + + let propagate_interval_binrel ~ex op dx dy = + match op with + | Rule -> + propagate_less_than ~ex ~strict:false dx dy + | Rugt -> + propagate_less_than ~ex ~strict:true dy dx + + type rel_t = + | Rbinrel of binrel * X.r * X.r + + let pp_rel_t ppf = function + | Rbinrel (op, x, y) -> + Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binrel op X.print x X.print y + + let equal_rel_t f1 f2 = + match f1, f2 with + | Rbinrel (op1, x1, y1), Rbinrel (op2, x2, y2) -> + equal_binrel op1 op2 && X.equal x1 x2 && X.equal y1 y2 + + let hash_rel_t = function + | Rbinrel (op, x, y) -> Hashtbl.hash (hash_binrel op, X.hash x, X.hash y) + + let normalize_rel_t = function + (* No normalization for relations yet *) + | r -> r + + let fold_args_rel_t f r acc = + match r with + | Rbinrel (_op, x, y) -> f y (f x acc) + + let subst_rel_t rr nrr = function + | Rbinrel (op, x, y) -> Rbinrel (op, X.subst rr nrr x, X.subst rr nrr y) + + let propagate_rel_t ~ex dom r = + let open Bitlist_domains.Ephemeral in + let get r = handle dom r in + match r with + | Rbinrel (op, x, y) -> + propagate_binrel ~ex op (get x) (get y) + + let propagate_interval_rel_t ~ex dom r = + let get r = Interval_domains.Ephemeral.handle dom r in + match r with + | Rbinrel (op, x, y) -> + propagate_interval_binrel ~ex op (get x) (get y) + type repr = | Cfun of X.r * fun_t + | Crel of rel_t let pp_repr ppf = function | Cfun (r, fn) -> Fmt.(pf ppf "%a =@ %a" (box X.print) r (box pp_fun_t) fn) + | Crel rel -> + pp_rel_t ppf rel let equal_repr c1 c2 = match c1, c2 with | Cfun (r1, f1), Cfun (r2, f2) -> X.equal r1 r2 && equal_fun_t f1 f2 + | Cfun _, _ | _, Cfun _ -> false + + | Crel r1, Crel r2 -> + equal_rel_t r1 r2 let hash_repr = function - | Cfun (r, f) -> Hashtbl.hash (X.hash r, hash_fun_t f) + | Cfun (r, f) -> Hashtbl.hash (0, X.hash r, hash_fun_t f) + | Crel r -> Hashtbl.hash (1, hash_rel_t r) let normalize_repr = function | Cfun (r, f) -> Cfun (r, normalize_fun_t f) + | Crel r -> Crel (normalize_rel_t r) let fold_args_repr f c acc = match c with | Cfun (r, fn) -> fold_args_fun_t f fn (f r acc) + | Crel r -> fold_args_rel_t f r acc let subst_repr rr nrr = function | Cfun (r, f) -> Cfun (X.subst rr nrr r, subst_fun_t rr nrr f) + | Crel r -> Crel (subst_rel_t rr nrr r) let propagate_repr ~ex dom = function | Cfun (r, f) -> propagate_fun_t ~ex dom r f + | Crel r -> propagate_rel_t ~ex dom r + + let propagate_interval_repr ~ex dom = function + | Cfun (r, f) -> propagate_interval_fun_t ~ex dom r f + | Crel r -> propagate_interval_rel_t ~ex dom r type t = { repr : repr ; mutable tag : int } @@ -298,6 +492,13 @@ end = struct let bvor = cbinop Bor let bvxor = cbinop Bxor + let crel r = hcons @@ Crel r + + let cbinrel op x y = crel (Rbinrel (op, x, y)) + + let bvule = cbinrel Rule + let bvugt = cbinrel Rugt + let equal c1 c2 = c1.tag = c2.tag let hash c = Hashtbl.hash c.tag @@ -309,9 +510,12 @@ end = struct let subst rr nrr c = hcons @@ subst_repr rr nrr c.repr - let propagate ~ex c dom = + let propagate_bitlist ~ex c dom = propagate_repr ~ex dom c.repr + let propagate_interval ~ex c dom = + propagate_interval_repr ~ex dom c.repr + let simplify_binop acts op r x y = let acts_add_zero r = let sz = match X.type_info r with Tbitv n -> n | _ -> assert false in @@ -335,8 +539,19 @@ end = struct let simplify_fun_t acts r = function | Fbinop (op, x, y) -> simplify_binop acts op r x y + let simplify_binrel acts op x y = + match op with + | Rugt when X.equal x y -> + acts.Rel_utils.acts_add_eq X.top X.bot; + true + | Rule | Rugt -> false + + let simplify_rel_t acts = function + | Rbinrel (op, x, y) -> simplify_binrel acts op x y + let simplify_repr acts = function | Cfun (r, f) -> simplify_fun_t acts r f + | Crel r -> simplify_rel_t acts r let simplify c acts = simplify_repr acts c.repr @@ -416,7 +631,7 @@ module Any_constraint = struct | Constraint of Constraint.t Rel_utils.explained | Structural of X.r (** Structural constraint associated with [X.r]. See - {!Rel_utils.Domains.structural_propagation}. *) + {!Rel_utils.Bitlist_domains.structural_propagation}. *) let equal a b = match a, b with @@ -428,55 +643,256 @@ module Any_constraint = struct | Constraint c -> 2 * Constraint.hash c.value | Structural r -> 2 * X.hash r + 1 - let propagate c d = + let propagate constraint_propagate structural_propagation c d = match c with | Constraint { value; explanation = ex } -> - Constraint.propagate ~ex value d + constraint_propagate ~ex value d | Structural r -> - Domains.Ephemeral.structural_propagation d r + structural_propagation d r end module QC = Uqueue.Make(Any_constraint) -(* Propagate: - - - The constraints that were never propagated since they were added - - The constraints involving variables whose domain changed since the last - propagation - - Iterate until fixpoint is reached. *) -let propagate eqs bcs dom = +(* Compute the number of most significant bits shared by [inf] and [sup]. + + Requires: [inf <= sup] + Ensures: + result is the greatest integer <= sz such that + inf = sup + + In particular, [result = sz] iff [inf = sup] and [result = 0] iff the most + significant bits of [inf] and [sup] differ. *) +let rec shared_msb sz inf sup = + let numbits_inf = Z.numbits inf in + let numbits_sup = Z.numbits sup in + assert (numbits_inf <= numbits_sup); + if numbits_inf = numbits_sup then + (* Top [sz - numbits_inf] bits are 0 in both; look at 1s *) + if numbits_inf = 0 then + sz + else + sz - numbits_inf + + shared_msb numbits_inf + (Z.extract (Z.lognot sup) 0 numbits_inf) + (Z.extract (Z.lognot inf) 0 numbits_inf) + else + (* Top [sz - numbits_sup] are 0 in both, the next significant bit differs *) + sz - numbits_sup + +let finite_lower_bound = function + | Intervals_intf.Unbounded -> Z.zero + | Closed n -> n + | Open n -> Z.succ n + +let finite_upper_bound ~size:sz = function + | Intervals_intf.Unbounded -> Z.extract Z.minus_one 0 sz + | Closed n -> n + | Open n -> Z.pred n + +(* If m and M are the minimal and maximal values of an union of intervals, the + longest sequence of most significant bits shared between m and M can be fixed + in the bit-vector domain; see "Is to BVs" in section 4.1 of + + Sharpening Constraint Programming approaches for Bit-Vector Theory. + Zakaria Chihani, Bruno Marre, François Bobot, Sébastien Bardin. + CPAIOR 2017. International Conference on AI and OR Techniques in + Constraint Programming for Combinatorial Optimization Problems, Jun + 2017, Padova, Italy. + https://cea.hal.science/cea-01795779/document + + Relevant excerpt: + + For example, m = 48 and M = 52 (00110000 and 00110100 in binary) share their + five most-significant bits, denoted [00110???]. Therefore, a bit-vector bl = + [0??1???0] can be refined into [00110??0]. *) +let constrain_bitlist_from_interval bv int = + let open Bitlist_domains.Ephemeral in + let sz = Bitlist.width !!bv in + + let inf, inf_ex = Intervals.Int.lower_bound int in + let inf = finite_lower_bound inf in + let sup, sup_ex = Intervals.Int.upper_bound int in + let sup = finite_upper_bound ~size:sz sup in + + let nshared = shared_msb sz inf sup in + if nshared > 0 then + let ex = Ex.union inf_ex sup_ex in + let shared_bl = + Bitlist.exact nshared (Z.extract inf (sz - nshared) nshared) ex + in + update ~ex bv @@ + Bitlist.concat shared_bl (Bitlist.unknown (sz - nshared) Ex.empty) + +(* Algorithm 1 from + + Sharpening Constraint Programming approaches for Bit-Vector Theory. + Zakaria Chihani, Bruno Marre, François Bobot, Sébastien Bardin. + CPAIOR 2017. International Conference on AI and OR Techniques in + Constraint Programming for Combinatorial Optimization Problems, Jun + 2017, Padova, Italy. + https://cea.hal.science/cea-01795779/document + + This function is a wrapper calling [Bitlist.increase_lower_bound] and + [Bitlist.decrease_upper_bound] on all the constituent intervals of an union; + see the documentation of these functions for details. *) +let constrain_interval_from_bitlist int bv = + let open Interval_domains.Ephemeral in + let ex = Bitlist.explanation bv in + (* Handy wrapper around [of_complement] *) + let remove ~ex i2 i1 = + match Intervals.Int.of_complement ~ex i1 with + | Empty _ -> invalid_arg "remove" + | NonEmpty i1 -> + match Intervals.Int.intersect i2 i1 with + | NonEmpty i -> i + | Empty ex -> + raise @@ Interval_domains.Inconsistent ex + in + update ~ex int @@ + Intervals.Int.fold (fun acc i -> + let { Intervals_intf.lb ; ub } = Intervals.Int.Interval.view i in + let lb = finite_lower_bound lb in + let ub = finite_upper_bound ~size:(Bitlist.width bv) ub in + let acc = + match Bitlist.increase_lower_bound bv lb with + | new_lb when Z.compare new_lb lb > 0 -> + (* lower bound increased; remove [lb, new_lb[ *) + remove ~ex acc + @@ Intervals.Int.Interval.of_bounds (Closed lb) (Open new_lb) + | new_lb -> + (* no change *) + assert (Z.equal new_lb lb); + acc + | exception Not_found -> + (* No value larger than lb matches the bit-pattern *) + remove ~ex acc + @@ Intervals.Int.Interval.of_bounds (Closed lb) Unbounded + in + let acc = + match Bitlist.decrease_upper_bound bv ub with + | new_ub when Z.compare new_ub ub < 0 -> + (* upper bound decreased; remove ]new_ub, ub] *) + remove ~ex acc + @@ Intervals.Int.Interval.of_bounds (Open new_ub) (Closed ub) + | new_ub -> + (* no change *) + assert (Z.equal new_ub ub); + acc + | exception Not_found -> + (* No value smaller than ub matches the bit-pattern *) + remove ~ex acc + @@ Intervals.Int.Interval.of_bounds Unbounded (Closed ub) + in + acc + ) !!int !!int + +let propagate_bitlist queue touched bcs dom = + let touch_c c = QC.push queue (Constraint c) in + let touch r = + HX.replace touched r (); + QC.push queue (Structural r); + Constraints.iter_parents touch_c r bcs + in + try + while true do + Bitlist_domains.Ephemeral.iter_changed touch dom; + Bitlist_domains.Ephemeral.clear_changed dom; + Any_constraint.propagate + Constraint.propagate_bitlist + Bitlist_domains.Ephemeral.structural_propagation + (QC.pop queue) dom + done + with QC.Empty -> () + +let propagate_intervals queue touched bcs dom = + let touch_c c = QC.push queue (Constraint c) in + let touch r = + HX.replace touched r (); + QC.push queue (Structural r); + Constraints.iter_parents touch_c r bcs + in + try + while true do + Interval_domains.Ephemeral.iter_changed touch dom; + Interval_domains.Ephemeral.clear_changed dom; + Any_constraint.propagate + Constraint.propagate_interval + Interval_domains.Ephemeral.structural_propagation + (QC.pop queue) dom + done + with QC.Empty -> () + +let propagate_all eqs bcs bdom idom = (* Call [simplify_pending] first because it can remove constraints from the pending set. *) let eqs, bcs = Constraints.simplify_pending eqs bcs in (* Optimization to avoid unnecessary allocations *) - if Constraints.has_pending bcs || Domains.has_changed dom then + let do_all = Constraints.has_pending bcs in + let do_bitlist = do_all || Bitlist_domains.has_changed bdom in + let do_intervals = do_all || Interval_domains.has_changed idom in + let do_any = do_bitlist || do_intervals in + if do_any then let queue = QC.create 17 in - let touch_c c = QC.push queue (Constraint c) in - Constraints.iter_pending touch_c bcs; + let touch_pending queue = + Constraints.iter_pending (fun c -> QC.push queue (Constraint c)) bcs + in + let bitlist_changed = HX.create 17 in + let touched = HX.create 17 in + let bdom = Bitlist_domains.edit bdom in + let idom = Interval_domains.edit idom in + + (* First, we propagate the pending constraints to both domains. Changes in + the bitlist domain are used to shrink the interval domains. *) + touch_pending queue; + propagate_bitlist queue touched bcs bdom; + assert (QC.is_empty queue); + + touch_pending queue; + HX.iter (fun r () -> + HX.replace bitlist_changed r (); + constrain_interval_from_bitlist + Interval_domains.Ephemeral.(handle idom r) + Bitlist_domains.Ephemeral.(!!(handle bdom r)) + ) touched; + HX.clear touched; + propagate_intervals queue touched bcs idom; + assert (QC.is_empty queue); + + (* Now the interval domain is stable, but the new intervals may have an + impact on the bitlist domains, so we must shrink them again when + applicable. We repeat until a fixpoint is reached. *) let bcs = Constraints.clear_pending bcs in - let changed = HX.create 17 in - let touch r = - HX.replace changed r (); - QC.push queue (Structural r); - Constraints.iter_parents touch_c r bcs + while HX.length touched > 0 do + HX.iter (fun r () -> + constrain_bitlist_from_interval + Bitlist_domains.Ephemeral.(handle bdom r) + Interval_domains.Ephemeral.(!!(handle idom r)) + ) touched; + HX.clear touched; + propagate_bitlist queue touched bcs bdom; + assert (QC.is_empty queue); + + HX.iter (fun r () -> + HX.replace bitlist_changed r (); + constrain_interval_from_bitlist + Interval_domains.Ephemeral.(handle idom r) + Bitlist_domains.Ephemeral.(!!(handle bdom r)) + ) touched; + HX.clear touched; + propagate_intervals queue touched bcs idom; + assert (QC.is_empty queue); + done; + + let eqs = + HX.fold (fun r () acc -> + let d = Bitlist_domains.Ephemeral.(!!(handle bdom r)) in + add_eqs acc (Shostak.Bitv.embed r) d + ) bitlist_changed eqs in - let dom = Domains.edit dom in - ( - try - while true do - Domains.Ephemeral.iter_changed touch dom; - Domains.Ephemeral.clear_changed dom; - Any_constraint.propagate (QC.pop queue) dom - done - with QC.Empty -> () - ); - HX.fold (fun r () acc -> - let d = Domains.Ephemeral.(!!(handle dom r)) in - add_eqs acc (Shostak.Bitv.embed r) d - ) changed eqs, bcs, Domains.snapshot dom + + eqs, bcs, Bitlist_domains.snapshot bdom, Interval_domains.snapshot idom else - eqs, bcs, dom + eqs, bcs, bdom, idom type t = { delayed : Rel_utils.Delayed.t @@ -487,13 +903,18 @@ let empty uf = { delayed = Rel_utils.Delayed.create ~is_ready:X.is_constant dispatch ; constraints = Constraints.empty ; size_splits = Q.one }, - Uf.GlobalDomains.add (module Domains) Domains.empty (Uf.domains uf) + Uf.GlobalDomains.add (module Bitlist_domains) Bitlist_domains.empty @@ + Uf.GlobalDomains.add (module Interval_domains) Interval_domains.empty @@ + Uf.domains uf let assume env uf la = let ds = Uf.domains uf in - let domain = Uf.GlobalDomains.find (module Domains) ds in + let domain = Uf.GlobalDomains.find (module Bitlist_domains) ds in + let int_domain = + Uf.GlobalDomains.find (module Interval_domains) ds + in let delayed, result = Rel_utils.Delayed.assume env.delayed uf la in - let (domain, constraints, eqs, size_splits) = + let (domain, int_domain, constraints, eqs, size_splits) = try let (constraints, eqs, size_splits) = List.fold_left (fun (bcs, eqs, ss) (a, _root, ex, orig) -> @@ -511,6 +932,18 @@ let assume env uf la = | L.Eq (rr, nrr), Subst when is_bv_r rr -> let bcs = Constraints.subst ~ex rr nrr bcs in (bcs, eqs, ss) + | Builtin (is_true, BVULE, [x; y]), _ -> + let x, exx = Uf.find_r uf x in + let y, exy = Uf.find_r uf y in + let ex = Ex.union ex @@ Ex.union exx exy in + let c = + if is_true then + Constraint.bvule x y + else + Constraint.bvugt x y + in + let bcs = Constraints.add ~ex c bcs in + (bcs, eqs, ss) | L.Distinct (false, [rr; nrr]), _ when is_1bit rr -> (* We don't (yet) support [distinct] in general, but we must support it for case splits to avoid looping. @@ -526,17 +959,22 @@ let assume env uf la = (env.constraints, [], env.size_splits) la in - let eqs, constraints, domain = propagate eqs constraints domain in + let eqs, constraints, domain, int_domain = + propagate_all eqs constraints domain int_domain + in if Options.get_debug_bitv () && not (Lists.is_empty eqs) then ( Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" - "bitlist domain: @[%a@]" Domains.pp domain; + "bitlist domain: @[%a@]" Bitlist_domains.pp domain; + Printer.print_dbg + ~module_name:"Bitv_rel" ~function_name:"assume" + "interval domain: @[%a@]" Interval_domains.pp int_domain; Printer.print_dbg ~module_name:"Bitv_rel" ~function_name:"assume" "bitlist constraints: @[%a@]" Constraints.pp constraints; ); - (domain, constraints, eqs, size_splits) - with Bitlist.Inconsistent ex -> + (domain, int_domain, constraints, eqs, size_splits) + with Bitlist.Inconsistent ex | Interval_domain.Inconsistent ex -> raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf) in let assume = @@ -546,7 +984,8 @@ let assume env uf la = { result with assume = List.rev_append assume result.assume } in { delayed ; constraints ; size_splits }, - Uf.GlobalDomains.add (module Domains) domain ds, + Uf.GlobalDomains.add (module Bitlist_domains) domain @@ + Uf.GlobalDomains.add (module Interval_domains) int_domain ds, result let query _ _ _ = None @@ -555,7 +994,9 @@ let case_split env uf ~for_model = if not for_model && Stdlib.(env.size_splits >= Options.get_max_split ()) then [] else - let domain = Uf.GlobalDomains.find (module Domains) (Uf.domains uf) in + let domain = + Uf.GlobalDomains.find (module Bitlist_domains) (Uf.domains uf) + in (* Look for representatives with minimal, non-fully known, domain size. We first look among the constrained variables, then if there are no @@ -579,7 +1020,7 @@ let case_split env uf ~for_model = match bv with | Bitv.Cte _ -> acc | Other r | Ext (r, _, _, _) -> - let bl = Domains.get r.value domain in + let bl = Bitlist_domains.get r.value domain in f_acc r.value bl acc ) acc (Shostak.Bitv.embed r) in @@ -587,14 +1028,14 @@ let case_split env uf ~for_model = match Constraints.fold_args f_acc' env.constraints None with | Some (nunk, xs) -> nunk, xs | _ -> - match Domains.fold_leaves f_acc domain None with + match Bitlist_domains.fold_leaves f_acc domain None with | Some (nunk, xs) -> nunk, xs | None -> 0, SX.empty in (* For now, just pick a value for the most significant bit. *) match SX.choose candidates with | r -> - let bl = Domains.get r domain in + let bl = Bitlist_domains.get r domain in let w = Bitlist.width bl in let unknown = Z.extract (Z.lognot @@ Bitlist.bits_known bl) 0 w in let bitidx = Z.numbits unknown - 1 in @@ -615,11 +1056,8 @@ let add env uf r t = let delayed, eqs = Rel_utils.Delayed.add env.delayed uf r t in let env, eqs = if is_bv_r r then - try - let constraints = extract_constraints env.constraints uf r t in - { env with constraints }, eqs - with Domains.Inconsistent ex -> - raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf) + let constraints = extract_constraints env.constraints uf r t in + { env with constraints }, eqs else env, eqs in @@ -636,3 +1074,7 @@ let assume_th_elt t th_elt _ = | Util.Bitv -> failwith "This Theory does not support theories extension" | _ -> t + +module Test = struct + let shared_msb = shared_msb +end diff --git a/src/lib/reasoners/bitv_rel.mli b/src/lib/reasoners/bitv_rel.mli index cab8616477..d3946e42df 100644 --- a/src/lib/reasoners/bitv_rel.mli +++ b/src/lib/reasoners/bitv_rel.mli @@ -27,4 +27,8 @@ include Sig_rel.RELATION +(**/**) +module Test : sig + val shared_msb : int -> Z.t -> Z.t -> int +end diff --git a/src/lib/reasoners/intervals.ml b/src/lib/reasoners/intervals.ml index f83ee3ce5c..bc012ae6a1 100644 --- a/src/lib/reasoners/intervals.ml +++ b/src/lib/reasoners/intervals.ml @@ -47,6 +47,11 @@ module Log = struct ) end +let map_bound f = function + | Unbounded -> Unbounded + | Open x -> Open (f x) + | Closed x -> Closed (f x) + module Ring(C : Core)(RT : RingType) = struct include C.Union(RT) @@ -82,6 +87,16 @@ module Ring(C : Core)(RT : RingType) = struct trace2 "add" u1 u2 @@ of_set_nonempty @@ map2_mon_to_set RT.add Inc u1 Inc u2 + let scale alpha u = + let alpha = RT.finite alpha in + let c = RT.compare alpha RT.zero in + if c < 0 then + map_strict_dec (RT.mul alpha) u + else if c > 0 then + map_strict_inc (RT.mul alpha) u + else + invalid_arg "scale: cannot scale by zero" + let mul u1 u2 = trace2 "mul" u1 u2 @@ of_set_nonempty @@ trisection_map_to_set RT.zero u1 @@ -289,6 +304,11 @@ module ZEuclideanType = struct | Finite x, Neg_infinite -> if Z.sign x < 0 then Finite Z.one else Finite Z.zero + + let lognot = function + | Neg_infinite -> Pos_infinite + | Pos_infinite -> Neg_infinite + | Finite n -> Finite (Z.lognot n) end (* AlgebraicType interface for reals @@ -571,7 +591,44 @@ type 'a union = 'a Core.union module Real = AlgebraicField(Core)(QAlgebraicType) -module Int = EuclideanRing(Core)(ZEuclideanType) +module Int = struct + include EuclideanRing(Core)(ZEuclideanType) + + let extract u ~ofs ~len = + if ofs < 0 || len <= 0 then invalid_arg "extract"; + trace1 (Fmt.str "extract ~ofs:%d ~len:%d" ofs len) u @@ + let max_val = Z.extract Z.minus_one 0 len in + let full = Interval.of_bounds (Closed Z.zero) (Closed max_val) in + of_set_nonempty @@ + map_to_set (fun i -> + match i.lb, i.ub with + | ZEuclideanType.Neg_infinite, _ | _, Pos_infinite -> + interval_set full + | _, Neg_infinite | Pos_infinite, _ -> + assert false + | Finite lb, Finite ub -> + let lb = Z.shift_right lb ofs in + let ub = Z.shift_right ub ofs in + if Z.(numbits (ub - lb)) <= len then + (* The image spans an interval of length at most [len] *) + let lb_mod = Z.extract lb 0 len in + let ub_mod = Z.extract ub 0 len in + if Z.(compare lb_mod ub_mod) <= 0 then + interval_set @@ Interval.of_bounds (Closed lb_mod) (Closed ub_mod) + else + union_set + (interval_set @@ Interval.of_bounds + (Closed Z.zero) (Closed ub_mod)) + (interval_set @@ Interval.of_bounds + (Closed lb_mod) (Closed max_val)) + else + (* The image is too large; all values are possible. *) + interval_set full + ) u + + let lognot u = + trace1 "lognot" u @@ map_strict_dec ZEuclideanType.lognot u +end module Legacy = struct type t = Real of Real.t | Int of Int.t @@ -775,11 +832,6 @@ module Legacy = struct in (lb, ub) - let map_bound f = function - | Unbounded -> Unbounded - | Open x -> Open (f x) - | Closed x -> Closed (f x) - let lower_bound = function | Real u -> Real.lower_bound u | Int u -> diff --git a/src/lib/reasoners/intervals.mli b/src/lib/reasoners/intervals.mli index 39fb279606..72d1b10cda 100644 --- a/src/lib/reasoners/intervals.mli +++ b/src/lib/reasoners/intervals.mli @@ -27,6 +27,10 @@ open Intervals_intf +val map_bound : ('a -> 'b) -> 'a bound -> 'b bound +(** [map_bound f b] applies [f] to a finite (open or closed) bound [b] and + does not change an unbounded bound. *) + (** This module provides implementations of union-of-intervals over reals and integers. *) @@ -40,11 +44,36 @@ module Real : AlgebraicField and type 'a union = 'a union (** Union-of-intervals over real numbers. *) -module Int : EuclideanRing - with type explanation := Explanation.t - and type value := Z.t - and type 'a union = 'a union (** Union-of-intervals over integers. *) +module Int : sig + include EuclideanRing + with type explanation := Explanation.t + and type value := Z.t + and type 'a union = 'a union + + (** {2 Bit-vector helpers} + + These functions are intended for the BV theory. They can only be used + with integer intervals. Some of these functions return intervals "of + width [n]", where [n] is computed from the parameters of the + function. This means that the returned interval is contained in the + range [[0, n)] ([0] inclusive, [n] exclusive). *) + + val lognot : t -> t + (** Bitwise logical negation. [lognot u] always returns [-u - 1]. *) + + val extract : t -> ofs:int -> len:int -> t + (** [extract s i j] returns the bits of [s] from position [i] to [j], + inclusive. + + Represents the function [fun x -> floor(x / 2^i) % 2^(j - i + 1)]. + + Requires [0 <= i <= j] and returns an interval of width [j - i + 1]. + + {b Note}: The interval [s] must be an integer interval, but is + allowed to be unbounded (in which case [extract s i j] returns the + full interval [[0, 2^(j - i + 1) - 1]]). *) +end module Legacy : sig (** The [Legacy] module reimplements (most of) the old legacy [Intervals] diff --git a/src/lib/reasoners/intervals_intf.ml b/src/lib/reasoners/intervals_intf.ml index 3cfe985a50..703444fb64 100644 --- a/src/lib/reasoners/intervals_intf.ml +++ b/src/lib/reasoners/intervals_intf.ml @@ -650,6 +650,12 @@ module type Ring = sig (** [add u1 u2] evaluates to {m \{ x + y \mid x \in S_1, y \in S_2 \}} when [u1] evaluates to {m S_1} and [u2] evaluates to {m S_2}. *) + val scale : value -> t -> t + (** [scale v u] evaluates to {m \{ v \times x \mid x \in S \}} when [u] + evaluates to {m S}. + + @raise Invalid_argument if [v] is zero. *) + val mul : t -> t -> t (** [mul u1 u2] evaluates to {m \{ x \times y \mid x \in S_1, y \in S_2 \}} when [u1] evaluates to {m S_1} and [u2] evaluates to {m S_2}. *) diff --git a/src/lib/reasoners/rel_utils.ml b/src/lib/reasoners/rel_utils.ml index 272fac045f..854077ff20 100644 --- a/src/lib/reasoners/rel_utils.ml +++ b/src/lib/reasoners/rel_utils.ml @@ -267,8 +267,8 @@ module type Domain = sig @raise Inconsistent if [r] cannot possibly be in the domain of [t]. *) - val map_leaves : (X.r -> 'a -> t) -> X.r -> 'a -> t - (** [map_leaves f r acc] is the "inverse" of [fold_leaves] in the sense that + val map_leaves : (X.r -> t) -> X.r -> t + (** [map_leaves f r] is the "inverse" of [fold_leaves] in the sense that it rebuilds a domain for [r] by using [f] to access the domain for each of [r]'s leaves. *) end @@ -403,9 +403,9 @@ struct ) leaves_map (X.leaves r) let create_domain r = - Domain.map_leaves (fun r () -> + Domain.map_leaves (fun r -> Domain.unknown (X.type_info r) - ) r () + ) r let add r t = if MX.mem r t.domains then t else @@ -581,7 +581,7 @@ struct if X.is_a_leaf parent then assert (X.equal r parent) else - update parent (Domain.map_leaves (fun r () -> get r) parent ()) + update parent (Domain.map_leaves get parent) ) parents | exception Not_found -> () else diff --git a/src/lib/structures/expr.ml b/src/lib/structures/expr.ml index 871955304e..0f18b80118 100644 --- a/src/lib/structures/expr.ml +++ b/src/lib/structures/expr.ml @@ -377,6 +377,12 @@ module SmtPrinter = struct | Sy.L_neg_built Sy.LT, [a; b] -> Fmt.pf ppf "@[<2>(>= %a %a@])" pp a pp b + | Sy.L_built Sy.BVULE, [a;b] -> + Fmt.pf ppf "@[<2>(bvule %a %a@])" pp a pp b + + | Sy.L_neg_built Sy.BVULE, [a;b] -> + Fmt.pf ppf "@[<2>(bvugt %a %a@])" pp a pp b + | Sy.L_neg_pred, [a] -> Fmt.pf ppf "@[<2>(not@ %a@])" pp a @@ -387,7 +393,8 @@ module SmtPrinter = struct Fmt.pf ppf "(not @[<2>((_ is %a)@ %a@]))" Uid.pp hs pp e - | (Sy.L_built (Sy.LT | Sy.LE) | Sy.L_neg_built (Sy.LT | Sy.LE) + | (Sy.L_built (Sy.LT | Sy.LE | Sy.BVULE) + | Sy.L_neg_built (Sy.LT | Sy.LE | Sy.BVULE) | Sy.L_neg_pred | Sy.L_eq | Sy.L_neg_eq | Sy.L_built (Sy.IsConstr _) | Sy.L_neg_built (Sy.IsConstr _)), _ -> @@ -559,6 +566,12 @@ module AEPrinter = struct | Sy.L_neg_built Sy.LT, [a; b] -> Fmt.pf ppf "(%a >= %a)" pp a pp b + | Sy.L_built Sy.BVULE, [a;b] -> + Fmt.pf ppf "(%a <= %a)" pp a pp b + + | Sy.L_neg_built Sy.BVULE, [a;b] -> + Fmt.pf ppf "(%a > %a)" pp a pp b + | Sy.L_neg_pred, [a] -> Fmt.pf ppf "(not %a)" pp a @@ -568,7 +581,8 @@ module AEPrinter = struct | Sy.L_neg_built (Sy.IsConstr hs), [e] -> Fmt.pf ppf "not (%a ? %a)" pp e Uid.pp hs - | (Sy.L_built (Sy.LT | Sy.LE) | Sy.L_neg_built (Sy.LT | Sy.LE) + | (Sy.L_built (Sy.LT | Sy.LE | Sy.BVULE) + | Sy.L_neg_built (Sy.LT | Sy.LE | Sy.BVULE) | Sy.L_neg_pred | Sy.L_eq | Sy.L_neg_eq | Sy.L_built (Sy.IsConstr _) | Sy.L_neg_built (Sy.IsConstr _)), _ -> @@ -3181,8 +3195,8 @@ module BV = struct (bvnot (bvlshr (bvnot s) t)) (* Comparisons *) - let bvult s t = Ints.(bv2nat s < bv2nat t) - let bvule s t = Ints.(bv2nat s <= bv2nat t) + let bvult s t = mk_builtin ~is_pos:false BVULE [t; s] + let bvule s t = mk_builtin ~is_pos:true BVULE [s; t] let bvugt s t = bvult t s let bvuge s t = bvule t s let bvslt s t = diff --git a/src/lib/structures/symbols.ml b/src/lib/structures/symbols.ml index b463012070..b4b591a3cc 100644 --- a/src/lib/structures/symbols.ml +++ b/src/lib/structures/symbols.ml @@ -28,6 +28,7 @@ type builtin = LE | LT (* arithmetic *) | IsConstr of Uid.t (* ADT tester *) + | BVULE (* unsigned bit-vector arithmetic *) type operator = | Tite @@ -197,7 +198,7 @@ let compare_builtin b1 b2 = Util.compare_algebraic b1 b2 (function | IsConstr h1, IsConstr h2 -> Uid.compare h1 h2 - | _, (LT | LE | IsConstr _) -> assert false + | _, (LT | LE | BVULE | IsConstr _) -> assert false ) let compare_lits lit1 lit2 = @@ -371,6 +372,8 @@ module AEPrinter = struct | L_neg_built LE -> Fmt.pf ppf ">" | L_neg_built LT -> Fmt.pf ppf ">=" | L_neg_pred -> Fmt.pf ppf "not " + | L_built BVULE -> Fmt.pf ppf "<=" + | L_neg_built BVULE -> Fmt.pf ppf ">" | L_built (IsConstr h) -> Fmt.pf ppf "? %a" Uid.pp h | L_neg_built (IsConstr h) -> diff --git a/src/lib/structures/symbols.mli b/src/lib/structures/symbols.mli index 12885e2368..4b28cae40e 100644 --- a/src/lib/structures/symbols.mli +++ b/src/lib/structures/symbols.mli @@ -28,6 +28,7 @@ type builtin = LE | LT (* arithmetic *) | IsConstr of Uid.t (* ADT tester *) + | BVULE (* unsigned bit-vector arithmetic *) type operator = | Tite diff --git a/src/lib/structures/xliteral.ml b/src/lib/structures/xliteral.ml index 6ebfcbd9dc..5862e5b685 100644 --- a/src/lib/structures/xliteral.ml +++ b/src/lib/structures/xliteral.ml @@ -28,6 +28,7 @@ type builtin = Symbols.builtin = LE | LT | (* arithmetic *) IsConstr of Uid.t (* ADT tester *) + | BVULE (* unsigned bit-vector arithmetic *) type 'a view = | Eq of 'a * 'a @@ -116,6 +117,15 @@ let print_view ?(lbl="") pr_elt fmt vw = | Builtin (_, (LE | LT), _) -> assert false (* not reachable *) + | Builtin (true, BVULE, [v1;v2]) -> + Format.fprintf fmt "%s %a <= %a" lbl pr_elt v1 pr_elt v2 + + | Builtin (false, BVULE, [v1;v2]) -> + Format.fprintf fmt "%s %a > %a" lbl pr_elt v1 pr_elt v2 + + | Builtin (_, BVULE, _) -> + assert false (* not reachable *) + | Builtin (pos, IsConstr hs, [e]) -> Format.fprintf fmt "%s(%a ? %a)" (if pos then "" else "not ") pr_elt e Uid.pp hs diff --git a/src/lib/structures/xliteral.mli b/src/lib/structures/xliteral.mli index 56bd0fdff1..7f83f90b20 100644 --- a/src/lib/structures/xliteral.mli +++ b/src/lib/structures/xliteral.mli @@ -28,6 +28,7 @@ type builtin = Symbols.builtin = LE | LT | (* arithmetic *) IsConstr of Uid.t (* ADT tester *) + | BVULE (* unsigned bit-vector arithmetic *) type 'a view = (*private*) | Eq of 'a * 'a diff --git a/tests/bitv/bitlist-interval001.expected b/tests/bitv/bitlist-interval001.expected new file mode 100644 index 0000000000..6f99ff0f44 --- /dev/null +++ b/tests/bitv/bitlist-interval001.expected @@ -0,0 +1,2 @@ + +unsat diff --git a/tests/bitv/bitlist-interval001.smt2 b/tests/bitv/bitlist-interval001.smt2 new file mode 100644 index 0000000000..0d860da43d --- /dev/null +++ b/tests/bitv/bitlist-interval001.smt2 @@ -0,0 +1,5 @@ +(set-logic ALL) +(declare-const x (_ BitVec 1024)) +(assert (bvult x (_ bv3 1024))) +(assert (= (bvand x (_ bv3 1024)) (_ bv3 1024))) +(check-sat) diff --git a/tests/bitvec_tests.ml b/tests/bitvec_tests.ml new file mode 100644 index 0000000000..e5b7c1b6d8 --- /dev/null +++ b/tests/bitvec_tests.ml @@ -0,0 +1,231 @@ +open AltErgoLib +open QCheck2 + +module IntSet : sig + type t + + val subset : t -> t -> bool + + val of_fold : ((Z.t -> t -> t) -> 'a -> t -> t) -> 'a -> t + + val map : (Z.t -> Z.t) -> t -> t + + val mem : Z.t -> t -> bool +end = struct + type t = Z.t + + let empty = Z.zero + + let is_empty n = Z.equal n empty + + let singleton n = + match Z.to_int n with + | exception Z.Overflow -> invalid_arg "IntSet.singleton" + | n -> Z.(one lsl n) + + let union = Z.(lor) + + let add n = union (singleton n) + + let subtract big small = Z.(big land (lognot small)) + + let subset small big = is_empty (subtract small big) + + let mem n t = + match Z.to_int n with + | exception Z.Overflow -> invalid_arg "IntSet.mem" + | n -> Z.testbit t n + + let fold f = + let rec aux ofs n acc = + let tz = Z.trailing_zeros n in + if tz <> max_int then + aux (ofs + tz + 1) Z.(n asr Stdlib.(tz + 1)) + (f (Z.of_int (ofs + tz)) acc) + else + acc + in + aux 0 + + let of_fold fold elt = fold add elt empty + + let map f t = + fold (fun n -> add (f n)) t empty +end + +let finite_bound = function + | Intervals_intf.Unbounded | Open _ -> assert false + | Closed n -> n + +let fold_finite_domain f int acc = + Intervals.Int.fold (fun acc i -> + let { Intervals_intf.lb ; ub } = Intervals.Int.Interval.view i in + let lb = finite_bound lb and ub = finite_bound ub in + let maxi = Z.(to_int @@ ub - lb) in + let acc = ref acc in + for i = 0 to maxi do + acc := f Z.(lb + ~$i) !acc + done; + !acc + ) acc int + +let of_interval = + IntSet.of_fold fold_finite_domain + +let of_bitlist = + IntSet.of_fold Bitlist.fold_domain + +(* Generator for bit-vectors of size sz *) +let bitvec sz = Gen.(int_bound sz >|= Z.of_int) + +(* Generator for a single interval bound *) +let interval sz = + assert (sz < Sys.int_size); + Gen.( + int_range 0 (1 lsl sz - 1) >>= fun lb -> + int_range lb (1 lsl sz - 1) >>= fun ub -> + return + @@ Intervals.Int.interval_set + @@ Intervals.Int.Interval.of_bounds + (Closed (Z.of_int lb)) (Closed (Z.of_int ub)) + ) + +(* Generator for an union of (possibly overlapping) intervals *) +let intervals sz = + let open Gen in + let rec intervals n = + if n <= 0 then + interval sz + else + let* i1 = interval sz in + let* i2 = intervals (n - 1) in + return @@ Intervals.Int.union_set i1 i2 + in + small_nat >>= intervals >|= Intervals.Int.of_set_nonempty + +(* Generator for a bitlist *) +let bitlist sz = + assert (sz < Sys.int_size); + let open Gen in + let rec bitlist sz = + if sz <= 0 then + return (Z.zero, Z.zero) + else + let* known = bool in + if known then + let* (set_bits, clr_bits) = bitlist (sz - 1) in + let mask = Z.shift_left Z.one (sz - 1) in + let* is_set = bool in + if is_set then + return (Z.logor set_bits mask, clr_bits) + else + return (set_bits, Z.logor clr_bits mask) + else + bitlist (sz - 1) + in + let* (set_bits, clr_bits) = bitlist sz in + let set_bits = + Bitlist.ones @@ Bitlist.exact sz set_bits Explanation.empty + in + let clr_bits = + Bitlist.zeroes @@ Bitlist.exact sz (Z.lognot clr_bits) Explanation.empty + in + return @@ Bitlist.intersect set_bits clr_bits + +(* Generator for extraction indices *) +let subvec sz = Gen.( + int_range 0 (sz - 1) >>= fun i -> + int_range i (sz - 1) >>= fun j -> + return (i, j) + ) + +(* Check that [Intervals.extract] computes a correct overapproximation of + the set of values obtained by the [extract] smt-lib function. *) +let test_extract sz = + Test.make ~count:1_000 + ~print:Print.(pair (Fmt.to_to_string Intervals.Int.pp) (pair int int)) + Gen.(pair (intervals sz) (subvec sz)) + (fun (t, (i, j)) -> + IntSet.subset + (IntSet.map (fun n -> Z.extract n i (j - i + 1)) (of_interval t)) + (of_interval (Intervals.Int.extract t ~ofs:i ~len:(j - i + 1)))) + +let () = + Test.check_exn (test_extract 3) + +(* Check that [shared_msb w z1 z2] returns exactly the number of most + significant bits that are shared between [z1] and [z2]. *) +let test_shared_msb sz = + Test.make ~count:1_000 + ~print:Print.( + pair (Fmt.to_to_string Z.pp_print) (Fmt.to_to_string Z.pp_print)) + Gen.(pair (bitvec sz) (bitvec sz)) + (fun (lb, ub) -> + let lb, ub = + if Z.compare lb ub > 0 then ub, lb else lb, ub + in + let shared = Bitv_rel.Test.shared_msb sz lb ub in + Z.equal + (Z.shift_right lb (sz - shared)) + (Z.shift_right ub (sz - shared)) && + (shared = sz || not @@ Z.equal + (Z.shift_right lb (sz - shared - 1)) + (Z.shift_right ub (sz - shared - 1)))) + +let () = + Test.check_exn (test_shared_msb 3) + +let all_range start_ end_ f = + try + for i = start_ to end_ do + if not (f i) then + raise Exit + done; + true + with Exit -> false + +(* Check that [increase_lower_bound b n] returns the smallest value larger + than [n] that matches [b]. Also check that if the [Not_found] exception is + raised, no such value exists. *) +let test_increase_lower_bound sz = + Test.make ~count:1_000 + ~print:Print.( + pair (Fmt.to_to_string Bitlist.pp) (Fmt.to_to_string Z.pp_print)) + Gen.(pair (bitlist sz) (bitvec sz)) + (fun (bl, z) -> + let set = of_bitlist bl in + match Bitlist.increase_lower_bound bl z with + | new_lb -> + Z.numbits new_lb <= sz && + IntSet.mem new_lb set && + all_range (Z.to_int z) (Z.to_int new_lb - 1) @@ fun i -> + not (IntSet.mem (Z.of_int i) set) + | exception Not_found -> + all_range (Z.to_int z) (sz - 1) @@ fun i -> + not (IntSet.mem (Z.of_int i) set)) + +let () = + Test.check_exn (test_increase_lower_bound 3) + +(* Check that [decrease_upper_bound] returns the largest value smaller than [n] + that matches [b]. Also check that if the [Not_found] exception is raised, no + such value exists. *) +let test_decrease_upper_bound sz = + Test.make ~count:1_000 + ~print:Print.( + pair (Fmt.to_to_string Bitlist.pp) (Fmt.to_to_string Z.pp_print)) + Gen.(pair (bitlist sz) (bitvec sz)) + (fun (bl, z) -> + let set = of_bitlist bl in + match Bitlist.decrease_upper_bound bl z with + | new_ub -> + Z.numbits new_ub <= sz && + IntSet.mem new_ub set && + all_range (Z.to_int new_ub + 1) (Z.to_int z) @@ fun i -> + not (IntSet.mem (Z.of_int i) set) + | exception Not_found -> + all_range 0 (Z.to_int z) @@ fun i -> + not (IntSet.mem (Z.of_int i) set)) + +let () = + Test.check_exn (test_decrease_upper_bound 3) diff --git a/tests/dune b/tests/dune index 34b0c7762e..a07748c073 100644 --- a/tests/dune +++ b/tests/dune @@ -24,3 +24,9 @@ (rule (alias runtest-ci) (action (diff dune.inc dune.inc.gen))) + +(test + (package alt-ergo-lib) + (name bitvec_tests) + (modules bitvec_tests) + (libraries alt-ergo-lib qcheck zarith)) diff --git a/tests/dune.inc b/tests/dune.inc index 3005afc12d..99478732c7 100644 --- a/tests/dune.inc +++ b/tests/dune.inc @@ -128715,7 +128715,256 @@ (alias runtest-quick) (package alt-ergo) (action - (diff testfile-array-cs.dolmen.expected testfile-array-cs.dolmen_dolmen.output)))) + (diff testfile-array-cs.dolmen.expected testfile-array-cs.dolmen_dolmen.output))) + (rule + (target bitlist-interval001_ci_cdcl_no_minimal_bj.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver CDCL + --no-minimal-bj + %{input}))))))) + (rule + (deps bitlist-interval001_ci_cdcl_no_minimal_bj.output) + (alias runtest-ci) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_ci_cdcl_no_minimal_bj.output))) + (rule + (target bitlist-interval001_ci_cdcl_tableaux_no_minimal_bj_no_tableaux_cdcl_in_theories_and_instantiation.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver CDCL-Tableaux + --no-minimal-bj + --no-tableaux-cdcl-in-theories + --no-tableaux-cdcl-in-instantiation + %{input}))))))) + (rule + (deps bitlist-interval001_ci_cdcl_tableaux_no_minimal_bj_no_tableaux_cdcl_in_theories_and_instantiation.output) + (alias runtest-ci) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_ci_cdcl_tableaux_no_minimal_bj_no_tableaux_cdcl_in_theories_and_instantiation.output))) + (rule + (target bitlist-interval001_ci_cdcl_tableaux_no_tableaux_cdcl_in_theories_and_instantiation.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver CDCL-Tableaux + --no-tableaux-cdcl-in-theories + --no-tableaux-cdcl-in-instantiation + %{input}))))))) + (rule + (deps bitlist-interval001_ci_cdcl_tableaux_no_tableaux_cdcl_in_theories_and_instantiation.output) + (alias runtest-ci) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_ci_cdcl_tableaux_no_tableaux_cdcl_in_theories_and_instantiation.output))) + (rule + (target bitlist-interval001_ci_no_tableaux_cdcl_in_instantiation.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver CDCL-Tableaux + --no-tableaux-cdcl-in-instantiation + %{input}))))))) + (rule + (deps bitlist-interval001_ci_no_tableaux_cdcl_in_instantiation.output) + (alias runtest-ci) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_ci_no_tableaux_cdcl_in_instantiation.output))) + (rule + (target bitlist-interval001_ci_cdcl_tableaux_no_tableaux_cdcl_in_theories.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver CDCL-Tableaux + --no-tableaux-cdcl-in-theories + %{input}))))))) + (rule + (deps bitlist-interval001_ci_cdcl_tableaux_no_tableaux_cdcl_in_theories.output) + (alias runtest-ci) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_ci_cdcl_tableaux_no_tableaux_cdcl_in_theories.output))) + (rule + (target bitlist-interval001_ci_tableaux_cdcl_no_minimal_bj.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver CDCL-Tableaux + --no-minimal-bj + %{input}))))))) + (rule + (deps bitlist-interval001_ci_tableaux_cdcl_no_minimal_bj.output) + (alias runtest-ci) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_ci_tableaux_cdcl_no_minimal_bj.output))) + (rule + (target bitlist-interval001_cdcl.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver CDCL + %{input}))))))) + (rule + (deps bitlist-interval001_cdcl.output) + (alias runtest-quick) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_cdcl.output))) + (rule + (target bitlist-interval001_tableaux_cdcl.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver Tableaux-CDCL + %{input}))))))) + (rule + (deps bitlist-interval001_tableaux_cdcl.output) + (alias runtest-quick) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_tableaux_cdcl.output))) + (rule + (target bitlist-interval001_tableaux.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + --sat-solver Tableaux + %{input}))))))) + (rule + (deps bitlist-interval001_tableaux.output) + (alias runtest-quick) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_tableaux.output))) + (rule + (target bitlist-interval001_dolmen.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --frontend dolmen + %{input}))))))) + (rule + (deps bitlist-interval001_dolmen.output) + (alias runtest-quick) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_dolmen.output))) + (rule + (target bitlist-interval001_fpa.output) + (deps (:input bitlist-interval001.smt2)) + (package alt-ergo) + (action + (chdir %{workspace_root} + (with-stdout-to %{target} + (ignore-stderr + (with-accepted-exit-codes (or 0) + (run %{bin:alt-ergo} + --timelimit=2 + --enable-assertions + --output=smtlib2 + --enable-theories fpa + %{input}))))))) + (rule + (deps bitlist-interval001_fpa.output) + (alias runtest-quick) + (package alt-ergo) + (action + (diff bitlist-interval001.expected bitlist-interval001_fpa.output)))) ; Auto-generated part end ; File auto-generated by gentests