Skip to content

Commit

Permalink
use int for cs
Browse files Browse the repository at this point in the history
  • Loading branch information
bclement-ocp committed Jul 24, 2024
1 parent d8a18fc commit b8262c0
Showing 1 changed file with 146 additions and 37 deletions.
183 changes: 146 additions & 37 deletions src/lib/reasoners/bitv_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ module Constraint : sig

type rel_t =
| Rbinrel of binrel * X.r * X.r
| Rdistinct of X.r array

type view =
| Cfun of X.r * fun_t
Expand Down Expand Up @@ -275,6 +276,8 @@ module Constraint : sig
val bvlshr : X.r -> X.r -> X.r -> t
(** [bvshl r x y] is the constraint [r = x >> y] *)

val distinct : X.r list -> t

val bvule : X.r -> X.r -> t

val bvugt : X.r -> X.r -> t
Expand Down Expand Up @@ -364,22 +367,32 @@ end = struct

type rel_t =
| Rbinrel of binrel * X.r * X.r
| Rdistinct of X.r array

let pp_rel_t ppf = function
| Rbinrel (op, x, y) ->
Fmt.pf ppf "%a@[(%a,@ %a)@]" pp_binrel op X.print x X.print y
| Rdistinct rs ->
Fmt.pf ppf "distinct@[(%a@])" (Fmt.array ~sep:Fmt.comma X.print) rs

let equal_rel_t f1 f2 =
match f1, f2 with
| Rbinrel (op1, x1, y1), Rbinrel (op2, x2, y2) ->
equal_binrel op1 op2 && X.equal x1 x2 && X.equal y1 y2
| Rbinrel _, _ | _, Rbinrel _ -> false
| Rdistinct rs1, Rdistinct rs2 ->
Array.length rs1 = Array.length rs2 &&
Array.for_all2 X.equal rs1 rs2

let hash_rel_t = function
| Rbinrel (op, x, y) -> Hashtbl.hash (hash_binrel op, X.hash x, X.hash y)
| Rdistinct rs ->
Array.fold_left (fun acc r -> Hashtbl.hash (acc, X.hash r)) 0 rs

let normalize_rel_t = function
(* No normalization for relations yet *)
| r -> r
| Rbinrel _ as r -> r
| Rdistinct rs as r ->
Array.sort X.hash_cmp rs; r

type view =
| Cfun of X.r * fun_t
Expand Down Expand Up @@ -453,6 +466,7 @@ end = struct

let crel r = hcons @@ Crel r

let distinct rs = crel (Rdistinct (Array.of_list rs))
let cbinrel op x y = crel (Rbinrel (op, x, y))

let bvule = cbinrel Rule
Expand All @@ -471,6 +485,7 @@ end = struct
let fold_args_rel_t f r acc =
match r with
| Rbinrel (_op, x, y) -> f y (f x acc)
| Rdistinct rs -> Array.fold_right f rs acc

let fold_args_view f c acc =
match c with
Expand Down Expand Up @@ -773,17 +788,82 @@ end = struct
| Rugt ->
propagate_less_than ~ex ~strict:true dy dx

module HZ = Hashtbl.Make(Z)

let propagate_interval_distinct ~ex dom rs =
let open Rel_utils.HandleNotations(Interval_domain)(Interval_domains_uf) in
let get r = Interval_domains_uf.entry dom r in
let remove ~ex:ex' d v =
update ~ex:Ex.empty d @@
match
Intervals.Int.of_complement ~ex:(Ex.union ex ex')
(Intervals.Int.Interval.singleton v)
with
| Empty _ -> assert false
| NonEmpty u -> u
in
match rs with
| [| |] | [| _ |] -> ()
| [| x; y |] -> (
let dx = get x and dy = get y in
match Intervals.Int.value_opt !!dx, Intervals.Int.value_opt !!dy with
| Some (vx, ex_x), Some (vy, ex_y) ->
if Z.equal vx vy then
raise @@
Interval_domain.Inconsistent (Ex.union ex_x ex_y)
| Some (vx, ex_x), None ->
remove ~ex:ex_x dy vx
| None, Some (vy, ex_y) ->
remove ~ex:ex_y dx vy
| None, None -> ()
)
| _ ->
let seen = HZ.create 17 in
let table = HX.create (Array.length rs) in
Array.iter (fun r -> HX.replace table r (get r)) rs;
let rec loop () =
let to_filter = ref [] in
HX.filter_map_inplace (fun _ dr ->
match Intervals.Int.value_opt !!dr with
| Some (v, ex_r) ->
begin match HZ.find seen v with
| ex_r' ->
raise @@
Interval_domain.Inconsistent
(Ex.union ex (Ex.union ex_r ex_r'))
| exception Not_found ->
HZ.add seen v ex_r;
to_filter := (v, ex_r) :: !to_filter
end;
None
| None -> Some dr
) table;
match !to_filter with
| [] -> ()
| to_filter ->
List.iter (fun (v, ex_r) ->
HX.iter (fun _ d -> remove ~ex:ex_r d v) table
) to_filter;
loop ()
in
loop ()

let propagate_rel_t ~ex dom r =
let get r = Bitlist_domains_uf.entry dom r in
match r with
| Rbinrel (op, x, y) ->
propagate_binrel ~ex op (get x) (get y)
| Rdistinct _ ->
(* No bitlist propagation *)
()

let propagate_interval_rel_t ~ex dom r =
let get r = Interval_domains_uf.entry dom r in
match r with
| Rbinrel (op, x, y) ->
propagate_interval_binrel ~ex op (get x) (get y)
| Rdistinct rs ->
propagate_interval_distinct ~ex dom rs

let propagate_view ~ex dom = function
| Cfun (r, f) -> propagate_fun_t ~ex dom r f
Expand Down Expand Up @@ -1084,11 +1164,39 @@ end = struct
true
| Rule | Rugt -> false

let simplify_distinct acts uf rs =
let table = HX.create (Array.length rs) in
let ex_ref = ref Ex.empty in
let non_constant_ref = ref [] in
let has_constant = ref false in
Array.iter (fun r ->
let r, ex = Uf.find_r uf r in
begin match HX.find table r with
| ex' -> acts.acts_add_eq ~ex:(Ex.union ex ex') X.top X.bot
| exception Not_found -> HX.add table r ex
end;
if X.is_constant r then has_constant := true else (
non_constant_ref := r :: !non_constant_ref ;
ex_ref := Ex.union ex !ex_ref
)
) rs;
(* Recall that simplification always occurs after propagation, so that
constants have been removed from the domains. *)
!has_constant &&
match !non_constant_ref with
| [] | [ _ ] ->
true
| non_constant ->
acts.acts_add_constraint ~ex:!ex_ref (distinct non_constant);
true

let simplify_rel_t uf acts = function
| Rbinrel (op, x, y) ->
let x, ex_x = Uf.find_r uf x in
let y, ex_y = Uf.find_r uf y in
simplify_binrel ~ex:(Explanation.union ex_x ex_y) acts op x y
| Rdistinct rs ->
simplify_distinct acts uf rs

let simplify_view uf acts = function
| Cfun (r, f) -> simplify_fun_t uf acts r f
Expand Down Expand Up @@ -1631,16 +1739,14 @@ let assume env uf la =
int_domain
in
(domain, int_domain, eqs, ss)
| L.Distinct (false, [rr; nrr]), _ when is_1bit rr ->
(* We don't (yet) support [distinct] in general, but we must
support it for case splits to avoid looping.
We are a bit more general and support it for 1-bit vectors,
for which `distinct` can be expressed using `bvnot`. *)
let not_nrr =
Shostak.Bitv.is_mine (Bitv.lognot (Shostak.Bitv.embed nrr))
| L.Distinct (false, ((rr :: _) as rs)), _ when is_bv_r rr ->
let c = Constraint.distinct rs in
let int_domain =
List.fold_left (fun int_domain r ->
Interval_domains.watch (explained ~ex c) r int_domain
) int_domain rs
in
(domain, int_domain, (Uf.LX.mkv_eq rr not_nrr, ex) :: eqs, ss)
(domain, int_domain, eqs, ss)
| _ -> (domain, int_domain, eqs, ss)
)
(domain, int_domain, [], env.size_splits)
Expand Down Expand Up @@ -1677,7 +1783,7 @@ let case_split env uf ~for_model =
[]
else
let domain =
Uf.GlobalDomains.find (module Bitlist_domains) (Uf.domains uf)
Uf.GlobalDomains.find (module Interval_domains) (Uf.domains uf)
in
(* Look for representatives with minimal, non-fully known, domain size.
Expand All @@ -1687,46 +1793,49 @@ let case_split env uf ~for_model =
[nunk] is the number of unknown bits. *)
let f_acc r acc =
let r, _ = Uf.find_r uf r in
let bl = Bitlist_domains.get r domain in
let nunk = Z.popcount (Bitlist.unknown_bits bl) in
if nunk = 0 then
acc
else
let int = Interval_domains.get r domain in
match Intervals.Int.value_opt int with
| Some _ -> acc
| None ->
let sz = bitwidth r in
let size =
Intervals.Int.fold (fun a int ->
let int = Intervals.Int.Interval.view int in
let lb = finite_lower_bound int.lb in
let ub = finite_upper_bound ~size:sz int.ub in
Z.(a + ub - lb + ~$1)
) Z.zero int
in
match acc with
| Some (nunk', _) when nunk > nunk' -> acc
| Some (nunk', xs) when nunk = nunk' ->
Some (nunk', SX.add r xs)
| _ -> Some (nunk, SX.singleton r)
| Some (size', _) when Z.compare size size' > 0 -> acc
| Some (size', xs) when Z.equal size size' ->
Some (size', SX.add r xs)
| _ -> Some (size, SX.singleton r)
in
let _, candidates =
let size, candidates =
match SX.fold f_acc env.terms None with
| Some (nunk, xs) -> nunk, xs
| None -> 0, SX.empty
| Some (size, xs) -> size, xs
| None -> Z.zero, SX.empty
in
(* For now, just pick a value for the most significant bit. *)
match SX.choose candidates with
| r ->
let rr, _ = Uf.find_r uf r in
let bl = Bitlist_domains.get rr domain in
let int = Interval_domains.get rr domain in
let r =
let es = Uf.rclass_of uf r in
let es = Uf.rclass_of uf rr in
try Uf.make uf (Expr.Set.choose es)
with Not_found -> r
in
let w = bitwidth r in
let unknown = Z.extract (Bitlist.unknown_bits bl) 0 w in
let bitidx = Z.numbits unknown - 1 in
let lhs =
Shostak.Bitv.is_mine @@
Bitv.extract w bitidx bitidx (Shostak.Bitv.embed r)
in
(* Just always pick zero for now. *)
let zero = Shostak.Bitv.is_mine Bitv.[ { bv = Cte Z.zero ; sz = 1 } ] in
let rhs_z, _ = Intervals.Int.lower_bound int in
let rhs = const (bitwidth r) (finite_lower_bound rhs_z) in
if Options.get_debug_bitv () then
Printer.print_dbg
~module_name:"Bitv_rel" ~function_name:"case_split"
"[BV-CS-1] Setting %a to 0" X.print lhs;
[ Uf.LX.mkv_eq lhs zero, true, Th_util.CS (Th_util.Th_bitv, Q.of_int 2) ]
"[BV-CS-1] Setting %a to %a" X.print r X.print rhs;
[ Uf.LX.mkv_eq r rhs,
true,
Th_util.CS (Th_util.Th_bitv, Q.of_bigint size) ]
| exception Not_found -> []

let add env uf r t =
Expand Down

0 comments on commit b8262c0

Please sign in to comment.