diff --git a/src/lib/reasoners/bitlist.ml b/src/lib/reasoners/bitlist.ml index 6b9adaa483..e6a2c5c4f3 100644 --- a/src/lib/reasoners/bitlist.ml +++ b/src/lib/reasoners/bitlist.ml @@ -29,140 +29,152 @@ module Ex = Explanation exception Inconsistent of Ex.t -(** A bitlist representing the known bits of a bit-vector of width [width]. +(** A bitlist representing the known bits of an infinite-width bit-vector. + Negative numbers are represented in two's complement. Active bits in [bits_set] are necessarily equal to [1]. - Active bits in [bits_clr] are necessarily equal to [0]. + Active bits in [bits_unk] are not known and may be either [0] or [1]. + Bits that are active in neither [bits_set] nor [bits_unk] are equal to [0]. + + The sign is known (and equal to the sign of [bits_set]) if [bits_unk] is + positive, and the sign is unknown if [bits_unk] is negative. The explanation [ex] justifies that the bitlist applies. *) -type t = { width: int ; bits_set : Z.t ; bits_clr : Z.t ; ex : Ex.t } +type t = { bits_set : Z.t ; bits_unk : Z.t ; ex : Ex.t } + +let constant n ex = + { bits_set = n ; bits_unk = Z.zero ; ex } -let unknown width ex = - { width ; bits_set = Z.zero ; bits_clr = Z.zero ; ex } +let zero ex = constant Z.zero ex -let empty = - { width = 0 ; bits_set = Z.zero ; bits_clr = Z.zero ; ex = Ex.empty } +let empty = zero Ex.empty -let width { width; _ } = width +let unknown = { bits_set = Z.zero ; bits_unk = Z.minus_one ; ex = Ex.empty } let explanation { ex; _ } = ex -let exact width value ex = - { width - ; bits_set = Z.extract value 0 width - ; bits_clr = Z.extract (Z.lognot value) 0 width +let exact value ex = + { bits_set = value + ; bits_unk = Z.zero ; ex } let equal b1 b2 = - b1.width = b2.width && Z.equal b1.bits_set b2.bits_set && - Z.equal b1.bits_clr b2.bits_clr + Z.equal b1.bits_unk b2.bits_unk -let ones b = { b with bits_clr = Z.zero } +let ones b = { b with bits_unk = Z.(b.bits_unk lor ~!(b.bits_set)) } -let zeroes b = { b with bits_set = Z.zero } +let zeroes b = + { b with bits_set = Z.zero ; bits_unk = Z.logor b.bits_unk b.bits_set } let add_explanation ~ex b = { b with ex = Ex.union b.ex ex } -let pp ppf { width; bits_set; bits_clr; ex } = +let pp ppf { bits_set; bits_unk; ex } = + begin if Z.sign bits_unk < 0 then + (* Sign is not known *) + Fmt.pf ppf "(?)" + else if Z.sign bits_set < 0 then + Fmt.pf ppf "(1)" + else + Fmt.pf ppf "(0)" + end; + let width = Z.numbits bits_unk in for i = width - 1 downto 0 do if Z.testbit bits_set i then Fmt.pf ppf "1" - else if Z.testbit bits_clr i then - Fmt.pf ppf "0" - else + else if Z.testbit bits_unk i then Fmt.pf ppf "?" + else + Fmt.pf ppf "0" done; if Options.(get_verbose () || get_unsat_core ()) then Fmt.pf ppf " %a" (Fmt.box Ex.print) ex -let bitlist ~width ~bits_set ~bits_clr ex = - if not (Z.equal (Z.logand bits_set bits_clr) Z.zero) then - raise @@ Inconsistent ex; - - { width; bits_set; bits_clr ; ex } - -let bits_known b = Z.logor b.bits_set b.bits_clr - -let num_unknown b = b.width - Z.popcount (bits_known b) +let unknown_bits b = b.bits_unk let value b = b.bits_set -let is_fully_known b = - Z.(equal (shift_right (bits_known b + ~$1) b.width) ~$1) +let is_fully_known b = Z.equal b.bits_unk Z.zero let intersect b1 b2 = - let width = b1.width in let bits_set = Z.logor b1.bits_set b2.bits_set in - let bits_clr = Z.logor b1.bits_clr b2.bits_clr in - bitlist ~width ~bits_set ~bits_clr - (Ex.union b1.ex b2.ex) + let bits_unk = Z.logand b1.bits_unk b2.bits_unk in + (* If there is a bit that is known in both bitlists with different values, + the intersection is empty. *) + let distinct = Z.logxor b1.bits_set b2.bits_set in + let known = Z.lognot (Z.logor b1.bits_unk b2.bits_unk) in + if not (Z.equal (Z.logand distinct known) Z.zero) then + raise @@ Inconsistent (Ex.union b1.ex b2.ex); + + { bits_set ; bits_unk ; ex = Ex.union b1.ex b2.ex } + +let extract b ofs len = + if len = 0 then empty + else + (* Always consistent *) + { bits_set = Z.extract b.bits_set ofs len + ; bits_unk = Z.extract b.bits_unk ofs len + ; ex = b.ex + } -let concat b1 b2 = - let bits_set = Z.(logor (b1.bits_set lsl b2.width) b2.bits_set) - and bits_clr = Z.(logor (b1.bits_clr lsl b2.width) b2.bits_clr) - in +let lognot b = (* Always consistent *) - { width = b1.width + b2.width - ; bits_set - ; bits_clr - ; ex = Ex.union b1.ex b2.ex - } + { b with bits_set = Z.(~!(b.bits_set lor b.bits_unk))} -let ( @ ) = concat +let ( ~! ) = lognot -let extract b i j = +let logor b1 b2 = + (* A bit is set in [b1 | b2] if it is set in either [b1] or [b2]. *) + let bits_set = Z.logor b1.bits_set b2.bits_set in + (* A bit is unknown in [b1 | b2] if it is unknown in either [b1] or [b2], + unless is set to [1] in either [b1] or [b2]. *) + let bits_unk = + Z.logand (Z.logor b1.bits_unk b2.bits_unk) + (Z.lognot bits_set) + in (* Always consistent *) - { width = j - i + 1 - ; bits_set = Z.extract b.bits_set i (j - i + 1) - ; bits_clr = Z.extract b.bits_clr i (j - i + 1) - ; ex = b.ex + { bits_set + ; bits_unk + ; ex = Ex.union b1.ex b2.ex } -let lognot b = - (* Always consistent *) - { b with bits_set = b.bits_clr; bits_clr = b.bits_set } +let ( lor ) = logor let logand b1 b2 = - let width = b1.width in let bits_set = Z.logand b1.bits_set b2.bits_set in - let bits_clr = Z.logor b1.bits_clr b2.bits_clr in + (* A bit is unknown in [b1 & b2] if it is unknown in both [b1] and [b2], or if + it is equal to [1] in either side and unknown in the other. *) + let bits_unk = + Z.logor (Z.logand b1.bits_set b2.bits_unk) @@ + Z.logor (Z.logand b1.bits_unk b2.bits_set) @@ + Z.logand b1.bits_unk b2.bits_unk + in (* Always consistent *) - { width - ; bits_set - ; bits_clr + { bits_set + ; bits_unk ; ex = Ex.union b1.ex b2.ex } -let logor b1 b2 = - let width = b1.width in - let bits_set = Z.logor b1.bits_set b2.bits_set in - let bits_clr = Z.logand b1.bits_clr b2.bits_clr in - (* Always consistent *) - { width - ; bits_set - ; bits_clr - ; ex = Ex.union b1.ex b2.ex - } +let ( land ) = logand let logxor b1 b2 = - let width = b1.width in + (* A bit is unknown in [b1 ^ b2] if it is unknown in either [b1] or [b2]. *) + let bits_unk = Z.logor b1.bits_unk b2.bits_unk in + (* Need to mask because [Z.logxor] will compute a bogus value for unknown + bits. *) let bits_set = - Z.logor - (Z.logand b1.bits_set b2.bits_clr) - (Z.logand b1.bits_clr b2.bits_set) - and bits_clr = - Z.logor - (Z.logand b1.bits_set b2.bits_set) - (Z.logand b1.bits_clr b2.bits_clr) + Z.logand + (Z.logxor b1.bits_set b2.bits_set) + (Z.lognot bits_unk) in (* Always consistent *) - { width - ; bits_set - ; bits_clr + { bits_set + ; bits_unk ; ex = Ex.union b1.ex b2.ex } +let ( lxor ) = logxor + (* The logic for the [increase_lower_bound] function below is described in section 4.1 of @@ -176,9 +188,12 @@ let logxor b1 b2 = (* [left_cl_can_set highest_cleared cleared_can_set] returns the least-significant bit that is: - More significant than [highest_cleared], strictly; - - Set in [cleared_can_set] *) + - Set in [cleared_can_set] + + Raises [Not_found] if there are no such bit. *) let left_cl_can_set highest_cleared cleared_can_set = let can_set = Z.(cleared_can_set asr highest_cleared) in + if Z.equal can_set Z.zero then raise Not_found; highest_cleared + Z.trailing_zeros can_set let increase_lower_bound b lb = @@ -188,7 +203,7 @@ let increase_lower_bound b lb = [cleared_bits] contains the bits that were set in [lb] and got cleared in [r]; conversely, [set_bits] contains the bits that were cleared in [lb] and got set in [r]. *) - let r = Z.logor b.bits_set (Z.logand lb (Z.lognot b.bits_clr)) in + let r = Z.logor b.bits_set (Z.logand lb b.bits_unk) in let cleared_bits = Z.logand lb (Z.lognot r) in let set_bits = Z.logand (Z.lognot lb) r in @@ -227,10 +242,10 @@ let increase_lower_bound b lb = as when the most-significant changed bit was 0 and is now 1 (see [if] case above). *) let bit_to_clear = Z.numbits cleared_bits in - let cleared_can_set = Z.lognot @@ Z.logor r b.bits_clr in + let cleared_can_set = + Z.logand (Z.lognot r) (Z.logor b.bits_set b.bits_unk) + in let bit_to_set = left_cl_can_set bit_to_clear cleared_can_set in - if bit_to_set >= b.width then - raise Not_found; let r = Z.logor r Z.(~$1 lsl bit_to_set) in let mask = Z.(minus_one lsl bit_to_set) in Z.logand r @@ Z.logor mask b.bits_set @@ -238,70 +253,81 @@ let increase_lower_bound b lb = let decrease_upper_bound b ub = (* x <= ub <-> ~ub <= ~x *) - let sz = width b in - assert (Z.numbits ub <= sz); - let nub = - increase_lower_bound (lognot b) (Z.extract (Z.lognot ub) 0 sz) - in - Z.extract (Z.lognot nub) 0 sz + Z.lognot @@ increase_lower_bound (lognot b) (Z.lognot ub) let fold_domain f b acc = - if b.width <= 0 then + (* If [bits_unk] is negative, the domain is infinite. *) + if Z.sign b.bits_unk < 0 then invalid_arg "Bitlist.fold_domain"; + let width = Z.numbits b.bits_unk in let rec fold_domain_aux ofs b acc = - if ofs >= b.width then ( + if ofs >= width then ( assert (is_fully_known b); f (value b) acc - ) else if Z.testbit b.bits_clr ofs || Z.testbit b.bits_set ofs then + ) else if not (Z.testbit b.bits_unk ofs) then fold_domain_aux (ofs + 1) b acc else let mask = Z.(one lsl ofs) in + let bits_unk = Z.logand b.bits_unk (Z.lognot mask) in let acc = fold_domain_aux - (ofs + 1) { b with bits_clr = Z.logor b.bits_clr mask } acc + (ofs + 1) { b with bits_unk } acc in fold_domain_aux - (ofs + 1) { b with bits_set = Z.logor b.bits_set mask } acc + (ofs + 1) { b with bits_unk; bits_set = Z.logor b.bits_set mask } acc in fold_domain_aux 0 b acc +let shift_left a n = + { bits_set = Z.(a.bits_set lsl n) + ; bits_unk = Z.(a.bits_unk lsl n) + ; ex = a.ex + } + +let ( lsl ) = shift_left + +let shift_right a n = + { bits_set = Z.(a.bits_set asr n) + ; bits_unk = Z.(a.bits_unk asr n) + ; ex = a.ex + } + +let ( asr ) = shift_right + (* simple propagator: only compute known low bits *) let mul a b = - let sz = width a in - assert (width b = sz); - let ex = Ex.union (explanation a) (explanation b) in (* (a * 2^n) * (b * 2^m) = (a * b) * 2^(n + m) *) - let zeroes_a = Z.trailing_zeros @@ Z.lognot a.bits_clr in - let zeroes_b = Z.trailing_zeros @@ Z.lognot b.bits_clr in - if zeroes_a + zeroes_b >= sz then - exact sz Z.zero ex + let zeroes_a = Z.trailing_zeros @@ Z.logor a.bits_set a.bits_unk in + let zeroes_b = Z.trailing_zeros @@ Z.logor b.bits_set b.bits_unk in + if zeroes_a = max_int || zeroes_b = max_int then + zero ex else - let low_bits = - if zeroes_a + zeroes_b = 0 then empty - else exact (zeroes_a + zeroes_b) Z.zero ex - in - let a = extract a zeroes_a (zeroes_a + sz - width low_bits - 1) in - assert (width a + width low_bits = sz); - let b = extract b zeroes_b (zeroes_b + sz - width low_bits - 1) in - assert (width b + width low_bits = sz); + let a = a asr zeroes_a in + let b = b asr zeroes_b in + let zeroes = zeroes_a + zeroes_b in (* ((ah * 2^n) + al) * ((bh * 2^m) + bl) = al * bl (mod 2^(min n m)) *) - let low_a_known = Z.trailing_zeros @@ Z.lognot @@ bits_known a in - let low_b_known = Z.trailing_zeros @@ Z.lognot @@ bits_known b in + let low_a_known = Z.trailing_zeros @@ a.bits_unk in + let low_b_known = Z.trailing_zeros @@ b.bits_unk in let low_known = min low_a_known low_b_known in + let mid_bits = exact Z.(value a * value b) ex in let mid_bits = - if low_known = 0 then empty - else exact - low_known - Z.(extract (value a) 0 low_known * extract (value b) 0 low_known) - ex + if low_known = max_int then mid_bits + else extract mid_bits 0 low_known in - concat (unknown (sz - width mid_bits - width low_bits) Ex.empty) @@ - concat mid_bits low_bits + if low_known = max_int then + mid_bits lsl zeroes + else + let high_bits = + { bits_set = Z.zero + ; bits_unk = Z.minus_one + ; ex = Ex.empty } + in + ((high_bits lsl low_known) lor mid_bits) lsl zeroes -let shl a b = +let bvshl ~size:sz a b = (* If the minimum value for [b] is larger than the bitwidth, the result is zero. @@ -312,19 +338,21 @@ let shl a b = NB: we would like to use the lower bound from the interval domain for [b] here. *) match Z.to_int (increase_lower_bound b Z.zero) with - | n when n < width a -> - let low_zeros = Z.trailing_zeros @@ Z.lognot @@ a.bits_clr in - if low_zeros + n >= width a then - exact (width a) Z.zero (Ex.union (explanation a) (explanation b)) + | n when n < sz -> + let low_zeros = Z.trailing_zeros @@ Z.logor a.bits_set a.bits_unk in + if low_zeros + n >= sz then + extract (exact Z.zero (Ex.union (explanation a) (explanation b))) 0 sz else if low_zeros + n > 0 then - concat (unknown (width a - low_zeros - n) Ex.empty) @@ - exact (low_zeros + n) Z.zero (Ex.union (explanation a) (explanation b)) + ((extract unknown 0 (sz - low_zeros - n)) lsl (low_zeros + n)) lor + extract + (exact Z.zero (Ex.union (explanation a) (explanation b))) + 0 (low_zeros + n) else - unknown (width a) Ex.empty + extract unknown 0 sz | _ | exception Z.Overflow -> - exact (width a) Z.zero (explanation b) + constant Z.zero (explanation b) -let lshr a b = +let bvlshr ~size:sz a b = (* If the minimum value for [b] is larger than the bitwidth, the result is zero. @@ -335,24 +363,26 @@ let lshr a b = NB: we would like to use the lower bound from the interval domain for [b] here. *) match Z.to_int (increase_lower_bound b Z.zero) with - | n when n < width a -> - let sz = width a in - if Z.testbit a.bits_clr (sz - 1) then (* MSB is zero *) - let low_msb_zero = Z.numbits @@ Z.extract (Z.lognot a.bits_clr) 0 sz in + | n when n < sz -> + if not (Z.testbit a.bits_unk (sz - 1) || Z.testbit a.bits_set (sz - 1)) + then (* MSB is zero *) + let low_msb_zero = + Z.numbits @@ Z.extract (Z.logor a.bits_set a.bits_unk) 0 sz + in let nb_msb_zeros = sz - low_msb_zero in assert (nb_msb_zeros > 0); let nb_zeros = nb_msb_zeros + n in if nb_zeros >= sz then - exact sz Z.zero (Ex.union (explanation a) (explanation b)) + constant Z.zero (Ex.union (explanation a) (explanation b)) else - concat - (exact nb_zeros Z.zero (Ex.union (explanation a) (explanation b))) - (unknown (sz - nb_zeros) Ex.empty) + ( + constant Z.zero (Ex.union (explanation a) (explanation b)) + lsl (sz - nb_zeros) + ) lor (extract unknown 0 (sz - nb_zeros)) else if n > 0 then - concat - (exact n Z.zero (explanation b)) - (unknown (sz - n) Ex.empty) + (constant Z.zero (explanation b) lsl (sz - n)) lor + extract unknown 0 (sz - n) else - unknown sz Ex.empty + extract unknown 0 sz | _ | exception Z.Overflow -> - exact (width a) Z.zero (explanation b) + constant Z.zero (explanation b) diff --git a/src/lib/reasoners/bitlist.mli b/src/lib/reasoners/bitlist.mli index 14a69c888a..5d7779dffc 100644 --- a/src/lib/reasoners/bitlist.mli +++ b/src/lib/reasoners/bitlist.mli @@ -28,7 +28,9 @@ (** Bit-lists provide a domain on bit-vectors that represent the known bits sets to [1] and [0], respectively. - This module provides an implementation of bitlists and related operators.*) + This module provides an implementation of bitlists and related operators. + The bitlists provided by this module do not have a fixed width, and can + represent arbitrary-precision integers. *) type t (** The type of bitlists. @@ -49,22 +51,15 @@ val pp : t Fmt.t exception Inconsistent of Explanation.t (** Exception raised when an inconsistency is detected. *) -val unknown : int -> Explanation.t -> t -(** [unknown w ex] returns an bitlist of width [w] with no known bits. *) - -val empty : t -(** An empty bitlist of width 0 and no explanation. *) - -val width : t -> int -(** Returns the width of the bitlist. *) +val unknown : t +(** [unknown] is a bitlist that repersents all integers. *) val explanation : t -> Explanation.t (** Returns the explanation associated with the bitlist. See the type-level documentation for details. *) -val exact : int -> Z.t -> Explanation.t -> t -(** [exact w v ex] returns a bitlist of width [w] that represents the [w]-bits - constant [v]. *) +val exact : Z.t -> Explanation.t -> t +(** [exact v ex] returns a bitlist that represents the constant [v]. *) val equal : t -> t -> bool (** [equal b1 b2] returns [true] if the bitlists [b1] and [b2] are equal, albeit @@ -82,16 +77,15 @@ val add_explanation : ex:Explanation.t -> t -> t (** [add_explanation ~ex b] adds the explanation [ex] to the bitlist [b]. The returned bitlist has both the explanation of [b] and [ex] as explanation. *) -val bits_known : t -> Z.t -(** [bits_known b] returns the sets of bits known to be either [1] or [0] as a - bitmask. *) +val unknown_bits : t -> Z.t +(** [unknown_bits b] returns the set of unknown (or undetermined) bits in [b]. -val num_unknown : t -> int -(** [num_unknown b] returns the number of bits unknown in [b]. *) + The value of [Z.logand (Z.lognot (unknown_bits b)) n] is the same for any + [n] in the set represented by the bitlist [b]. *) val is_fully_known : t -> bool (** [is_fully_known b] determines if there are unknown bits in [b] or not. - [is_fully_known b] is [true] iff [num_unknown b] is [0]. *) + [is_fully_known b] is [true] iff [unknown_bits b] is [Z.zero]. *) val value : t -> Z.t (** [value b] returns the value associated with the bitlist [b]. If the bitlist @@ -110,33 +104,38 @@ val lognot : t -> t (** [lognot b] swaps the bits that are set and cleared. *) val logand : t -> t -> t -(** Bitwise and. *) +(** Bit-wise and. *) val logor : t -> t -> t -(** Bitwise or. *) +(** Bit-wise or. *) val logxor : t -> t -> t -(** Bitwise xor. *) +(** Bit-wise xor. *) val mul : t -> t -> t (** Multiplication. *) -val shl : t -> t -> t -(** Logical left shift. *) +val bvshl : size:int -> t -> t -> t +(** Logical left shift, truncated to the [size] least significant bits. *) -val lshr : t -> t -> t -(** Logical right shift. *) +val bvlshr : size:int -> t -> t -> t +(** Logical right shift, truncated to the [size] least significant bits. *) -val concat : t -> t -> t -(** Bit-vector concatenation. *) +val shift_left : t -> int -> t +(** Shifts to the left. Equivalent to a multiplication by a power of [2]. The + second argument must be nonnegative. *) -val ( @ ) : t -> t -> t -(** Alias for [concat]. *) +val shift_right : t -> int -> t +(** Shifts to the right. This is an arithmetic shift, equivalent to a division + by a power of [2] with rounding towards -oo. The second argument must be + nonnegative. *) val extract : t -> int -> int -> t -(** [extract b i j] returns the bitlist from index [i] to index [j] inclusive. +(** [extract b off len] returns a nonnegative bitlist corresponding to bits + [off] to [off + len - 1] of [b]. - The resulting bitlist has length [j - i + 1]. *) + {b Note}: This uses the same arguments as [Z.extract], not the arguments + from the SMT-LIB's [extract] primitive. *) val increase_lower_bound : t -> Z.t -> Z.t (** [increase_lower_bound b lb] returns the smallest integer [lb' >= lb] that @@ -150,6 +149,26 @@ val decrease_upper_bound : t -> Z.t -> Z.t @raise Not_found if no such integer exists. *) +(** {2 Prefix and infix operators} *) + +val ( land ) : t -> t -> t +(** Bit-wise logical and [logand]. *) + +val ( lor ) : t -> t -> t +(** Bit-wise logical inclusive or [logor]. *) + +val ( lxor ) : t -> t -> t +(** Bit-wise logical exclusive xor [logxor]. *) + +val ( ~! ) : t -> t +(** Bit-wise logical negation [lognot]. *) + +val ( lsl ) : t -> int -> t +(** Bit-wise shift to the left [shift_left]. *) + +val ( asr ) : t -> int -> t +(** Bit-wise shift to the right [shift_right]. *) + (**/**) (** [fold_finite_domain f i acc] accumulates [f] on all the elements of [i] (in diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index fc743066cb..72224c698c 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -75,6 +75,9 @@ let is_bv_ty = function let is_bv_r r = is_bv_ty @@ X.type_info r +let bitwidth r = + match X.type_info r with Tbitv n -> n | _ -> assert false + module Interval_domain = struct type t = Intervals.Int.t @@ -112,11 +115,7 @@ module Interval_domain = struct Intervals.Int.of_bounds ?ex (Closed n) (Closed n) let fold_leaves f r int acc = - let width = - match X.type_info r with - | Tbitv n -> n - | _ -> assert false - in + let width = bitwidth r in let j, acc = List.fold_left (fun (j, acc) { Bitv.bv; sz } -> (* sz = j - i + 1 => i = j - sz + 1 *) @@ -170,49 +169,53 @@ module Domain : Rel_utils.Domain with type t = Bitlist.t = struct let filter_ty = is_bv_ty - let fold_signed f { Bitv.value; negated } bl acc = - let bl = if negated then lognot bl else bl in + let fold_signed sz f { Bitv.value; negated } bl acc = + let bl = if negated then extract (lognot bl) 0 sz else bl in f value bl acc let fold_leaves f r bl acc = - fst @@ List.fold_left (fun (acc, bl) { Bitv.bv; sz } -> + let sz = bitwidth r in + let (acc, _, _) = List.fold_left (fun (acc, bl, w) { Bitv.bv; sz } -> (* Extract the bitlist associated with the current component *) - let mid = width bl - sz in - let bl_tail = - if mid = 0 then empty else - extract bl 0 (mid - 1) - in - let bl = extract bl mid (width bl - 1) in + let mid = w - sz in + let bl_tail = extract bl 0 mid in + let bl = extract bl mid (w - mid) in match bv with | Bitv.Cte z -> + assert (Z.numbits z <= sz); (* Nothing to update, but still check for consistency! *) - ignore @@ intersect bl (exact sz z Ex.empty); - acc, bl_tail - | Other r -> fold_signed f r bl acc, bl_tail + ignore @@ intersect bl (exact z Ex.empty); + acc, bl_tail, mid + | Other r -> fold_signed sz f r bl acc, bl_tail, mid | Ext (r, r_size, i, j) -> (* r = bl -> r = ?^(r_size - j - 1) @ bl @ ?^i *) - assert (i + r_size - j - 1 + width bl = r_size); - let hi = Bitlist.unknown (r_size - j - 1) Ex.empty in - let lo = Bitlist.unknown i Ex.empty in - fold_signed f r (hi @ bl @ lo) acc, bl_tail - ) (acc, bl) (Shostak.Bitv.embed r) - - let map_signed f { Bitv.value; negated } = + assert (i + r_size - j - 1 + w - mid = r_size); + let hi = Bitlist.(extract unknown 0 (r_size - j - 1)) in + let lo = Bitlist.(extract unknown 0 i) in + let bl_hd = Bitlist.((hi lsl (j + 1)) lor (bl lsl i) lor lo) in + fold_signed r_size f r bl_hd acc, + bl_tail, + mid + ) (acc, bl, sz) (Shostak.Bitv.embed r) + in acc + + let map_signed sz f { Bitv.value; negated } = let bl = f value in - if negated then lognot bl else bl + if negated then extract (lognot bl) 0 sz else bl let map_leaves f r = List.fold_left (fun bl { Bitv.bv; sz } -> - concat bl @@ + bl lsl sz lor match bv with - | Bitv.Cte z -> exact sz z Ex.empty - | Other r -> map_signed f r - | Ext (r, _r_size, i, j) -> extract (map_signed f r) i j - ) empty (Shostak.Bitv.embed r) + | Bitv.Cte z -> extract (exact z Ex.empty) 0 sz + | Other r -> map_signed sz f r + | Ext (r, r_sz, i, j) -> + extract (map_signed r_sz f r) i (j - i + 1) + ) (exact Z.zero Ex.empty) (Shostak.Bitv.embed r) let unknown = function - | Ty.Tbitv n -> unknown n Ex.empty + | Ty.Tbitv n -> extract unknown 0 n | _ -> (* Only bit-vector values can have bitlist domains. *) invalid_arg "unknown" @@ -327,46 +330,51 @@ end = struct | Band | Bor | Bxor | Badd | Bmul -> true | Budiv | Burem | Bshl | Blshr -> false - let propagate_binop ~ex dx op dy dz = + let propagate_binop ~ex sz dx op dy dz = let open Domains.Ephemeral in + let norm bl = Bitlist.extract bl 0 sz in match op with | Band -> - update ~ex dx (Bitlist.logand !!dy !!dz); + update ~ex dx @@ norm @@ Bitlist.logand !!dy !!dz; (* Reverse propagation for y: if [x = y & z] then: - Any [1] in [x] must be a [1] in [y] - Any [0] in [x] that is also a [1] in [z] must be a [0] in [y] *) - update ~ex dy (Bitlist.ones !!dx); - update ~ex dy Bitlist.(logor (zeroes !!dx) (lognot (ones !!dz))); - update ~ex dz (Bitlist.ones !!dx); - update ~ex dz Bitlist.(logor (zeroes !!dx) (lognot (ones !!dy))) + update ~ex dy @@ norm @@ Bitlist.ones !!dx; + update ~ex dy @@ norm @@ + Bitlist.(logor (zeroes !!dx) (lognot (ones !!dz))); + update ~ex dz @@ norm @@ Bitlist.ones !!dx; + update ~ex dz @@ norm @@ + Bitlist.(logor (zeroes !!dx) (lognot (ones !!dy))) | Bor -> - update ~ex dx (Bitlist.logor !!dy !!dz); + update ~ex dx @@ norm @@ Bitlist.logor !!dy !!dz; (* Reverse propagation for y: if [x = y | z] then: - Any [0] in [x] must be a [0] in [y] - Any [1] in [x] that is also a [0] in [z] must be a [1] in [y] *) - update ~ex dy (Bitlist.zeroes !!dx); - update ~ex dy Bitlist.(logand (ones !!dx) (lognot (zeroes !!dz))); - update ~ex dz (Bitlist.zeroes !!dx); - update ~ex dz Bitlist.(logand (ones !!dx) (lognot (zeroes !!dy))) + update ~ex dy @@ norm @@ Bitlist.zeroes !!dx; + update ~ex dy @@ norm @@ + Bitlist.(logand (ones !!dx) (lognot (zeroes !!dz))); + update ~ex dz @@ norm @@ Bitlist.zeroes !!dx; + update ~ex dz @@ norm @@ + Bitlist.(logand (ones !!dx) (lognot (zeroes !!dy))) | Bxor -> - update ~ex dx (Bitlist.logxor !!dy !!dz); + update ~ex dx @@ norm @@ Bitlist.logxor !!dy !!dz; (* x = y ^ z <-> y = x ^ z *) - update ~ex dy (Bitlist.logxor !!dx !!dz); - update ~ex dz (Bitlist.logxor !!dx !!dy) + update ~ex dy @@ norm @@ Bitlist.logxor !!dx !!dz; + update ~ex dz @@ norm @@ Bitlist.logxor !!dx !!dy | Badd -> (* TODO: full adder propagation *) () | Bshl -> (* Only forward propagation for now *) - update ~ex dx (Bitlist.shl !!dy !!dz) + update ~ex dx (Bitlist.bvshl ~size:sz !!dy !!dz) | Blshr -> (* Only forward propagation for now *) - update ~ex dx (Bitlist.lshr !!dy !!dz) + update ~ex dx (Bitlist.bvlshr ~size:sz !!dy !!dz) | Bmul -> (* Only forward propagation for now *) - update ~ex dx (Bitlist.mul !!dy !!dz) + update ~ex dx @@ norm @@ Bitlist.mul !!dy !!dz | Budiv -> (* No bitlist propagation for now *) () @@ -434,13 +442,14 @@ end = struct let get r = handle dom r in match f with | Fbinop (op, x, y) -> - propagate_binop ~ex (get r) op (get x) (get y) + let n = bitwidth r in + propagate_binop ~ex n (get r) op (get x) (get y) let propagate_interval_fun_t ~ex dom r f = let get r = Interval_domains.Ephemeral.handle dom r in match f with | Fbinop (op, x, y) -> - let sz = match X.type_info r with Tbitv n -> n | _ -> assert false in + let sz = bitwidth r in propagate_interval_binop ~ex sz (get r) op (get x) (get y) type binrel = Rule | Rugt @@ -629,9 +638,6 @@ end = struct let propagate_interval ~ex c dom = propagate_interval_repr ~ex dom c.repr - let bitwidth r = - match X.type_info r with Tbitv n -> n | _ -> assert false - let const sz n = Shostak.Bitv.is_mine [ { bv = Cte (Z.extract n 0 sz); sz } ] @@ -979,15 +985,15 @@ let rec mk_eq ex lhs w z = applies to [r], exposes the equality [r = bl] as a list of Xliteral values (accumulated into [acc]) so that the union-find learns about the equality *) let add_eqs = - let rec aux x x_sz acc bl = - let known = Bitlist.bits_known bl in - let width = Bitlist.width bl in + let rec aux x x_sz acc width bl = + let known = Z.lognot (Bitlist.unknown_bits bl) in + let known = Z.extract known 0 width in let nbits = Z.numbits known in assert (nbits <= width); if nbits = 0 then acc else if nbits < width then - aux x x_sz acc (Bitlist.extract bl 0 (nbits - 1)) + aux x x_sz acc nbits (Bitlist.extract bl 0 nbits) else let nbits = Z.numbits (Z.extract (Z.lognot known) 0 width) in let v = Z.extract (Bitlist.value bl) nbits (width - nbits) in @@ -997,10 +1003,10 @@ let add_eqs = if nbits = 0 then lits @ acc else - aux x x_sz (lits @ acc) (Bitlist.extract bl 0 (nbits - 1)) + aux x x_sz (lits @ acc) nbits (Bitlist.extract bl 0 nbits) in - fun acc x bl -> - aux x (Bitlist.width bl) acc bl + fun acc x x_sz bl -> + aux x x_sz acc x_sz bl module Any_constraint = struct type t = @@ -1030,32 +1036,6 @@ end module QC = Uqueue.Make(Any_constraint) -(* Compute the number of most significant bits shared by [inf] and [sup]. - - Requires: [inf <= sup] - Ensures: - result is the greatest integer <= sz such that - inf = sup - - In particular, [result = sz] iff [inf = sup] and [result = 0] iff the most - significant bits of [inf] and [sup] differ. *) -let rec shared_msb sz inf sup = - let numbits_inf = Z.numbits inf in - let numbits_sup = Z.numbits sup in - assert (numbits_inf <= numbits_sup); - if numbits_inf = numbits_sup then - (* Top [sz - numbits_inf] bits are 0 in both; look at 1s *) - if numbits_inf = 0 then - sz - else - sz - numbits_inf + - shared_msb numbits_inf - (Z.extract (Z.lognot sup) 0 numbits_inf) - (Z.extract (Z.lognot inf) 0 numbits_inf) - else - (* Top [sz - numbits_sup] are 0 in both, the next significant bit differs *) - sz - numbits_sup - let finite_lower_bound = function | Intervals_intf.Unbounded -> Z.zero | Closed n -> n @@ -1082,23 +1062,24 @@ let finite_upper_bound ~size:sz = function For example, m = 48 and M = 52 (00110000 and 00110100 in binary) share their five most-significant bits, denoted [00110???]. Therefore, a bit-vector bl = [0??1???0] can be refined into [00110??0]. *) -let constrain_bitlist_from_interval bv int = +let constrain_bitlist_from_interval ~size:sz bv int = let open Domains.Ephemeral in - let sz = Bitlist.width !!bv in let inf, inf_ex = Intervals.Int.lower_bound int in let inf = finite_lower_bound inf in let sup, sup_ex = Intervals.Int.upper_bound int in let sup = finite_upper_bound ~size:sz sup in - let nshared = shared_msb sz inf sup in - if nshared > 0 then - let ex = Ex.union inf_ex sup_ex in - let shared_bl = - Bitlist.exact nshared (Z.extract inf (sz - nshared) nshared) ex - in - update ~ex bv @@ - Bitlist.concat shared_bl (Bitlist.unknown (sz - nshared) Ex.empty) + let distinct = Z.logxor inf sup in + (* If [distinct] is negative, [inf] and [sup] have different signs and there + are no significant shared bits. *) + if Z.sign distinct >= 0 then + let ofs = Z.numbits distinct in + update ~ex:Ex.empty bv @@ + Bitlist.( + exact Z.((inf asr ofs) lsl ofs) (Ex.union inf_ex sup_ex) lor + extract unknown 0 ofs + ) (* Algorithm 1 from @@ -1112,7 +1093,7 @@ let constrain_bitlist_from_interval bv int = This function is a wrapper calling [Bitlist.increase_lower_bound] and [Bitlist.decrease_upper_bound] on all the constituent interavals of an union; see the documentation of these functions for details. *) -let constrain_interval_from_bitlist int bv = +let constrain_interval_from_bitlist ~size:sz int bv = let open Interval_domains.Ephemeral in let ex = Bitlist.explanation bv in (* Handy wrapper around [of_complement] *) @@ -1129,7 +1110,7 @@ let constrain_interval_from_bitlist int bv = Intervals.Int.fold (fun acc i -> let { Intervals_intf.lb ; ub } = Intervals.Int.Interval.view i in let lb = finite_lower_bound lb in - let ub = finite_upper_bound ~size:(Bitlist.width bv) ub in + let ub = finite_upper_bound ~size:sz ub in let acc = match Bitlist.increase_lower_bound bv lb with | new_lb when Z.compare new_lb lb > 0 -> @@ -1227,7 +1208,8 @@ let propagate_all eqs bcs bdom idom = touch_pending queue; HX.iter (fun r () -> HX.replace bitlist_changed r (); - constrain_interval_from_bitlist + let sz = bitwidth r in + constrain_interval_from_bitlist ~size:sz Interval_domains.Ephemeral.(handle idom r) Domains.Ephemeral.(!!(handle bdom r)) ) touched; @@ -1241,7 +1223,8 @@ let propagate_all eqs bcs bdom idom = let bcs = Constraints.clear_pending bcs in while HX.length touched > 0 do HX.iter (fun r () -> - constrain_bitlist_from_interval + let sz = bitwidth r in + constrain_bitlist_from_interval ~size:sz Domains.Ephemeral.(handle bdom r) Interval_domains.Ephemeral.(!!(handle idom r)) ) touched; @@ -1250,8 +1233,9 @@ let propagate_all eqs bcs bdom idom = assert (QC.is_empty queue); HX.iter (fun r () -> + let sz = bitwidth r in HX.replace bitlist_changed r (); - constrain_interval_from_bitlist + constrain_interval_from_bitlist ~size:sz Interval_domains.Ephemeral.(handle idom r) Domains.Ephemeral.(!!(handle bdom r)) ) touched; @@ -1263,7 +1247,8 @@ let propagate_all eqs bcs bdom idom = let eqs = HX.fold (fun r () acc -> let d = Domains.Ephemeral.(!!(handle bdom r)) in - add_eqs acc (Shostak.Bitv.embed r) d + let sz = bitwidth r in + add_eqs acc (Shostak.Bitv.embed r) sz d ) bitlist_changed eqs in @@ -1379,7 +1364,7 @@ let case_split env uf ~for_model = [nunk] is the number of unknown bits. *) let f_acc r bl acc = - let nunk = Bitlist.num_unknown bl in + let nunk = Z.popcount (Bitlist.unknown_bits bl) in if nunk = 0 then acc else @@ -1411,8 +1396,8 @@ let case_split env uf ~for_model = match SX.choose candidates with | r -> let bl = Domains.get r domain in - let w = Bitlist.width bl in - let unknown = Z.extract (Z.lognot @@ Bitlist.bits_known bl) 0 w in + let w = bitwidth r in + let unknown = Z.extract (Bitlist.unknown_bits bl) 0 w in let bitidx = Z.numbits unknown - 1 in let lhs = Shostak.Bitv.is_mine @@ @@ -1451,5 +1436,6 @@ let assume_th_elt t th_elt _ = | _ -> t module Test = struct - let shared_msb = shared_msb + let shared_msb sz inf sup = + sz - Z.numbits (Z.logxor inf sup) end diff --git a/tests/bitvec_tests.ml b/tests/bitvec_tests.ml index 723cd37008..b6f95a58d9 100644 --- a/tests/bitvec_tests.ml +++ b/tests/bitvec_tests.ml @@ -1,6 +1,30 @@ open AltErgoLib open QCheck2 +module type FixedSizeBitVector = sig + type t = Bitlist.t + + val shl : t -> t -> t + + val lshr : t -> t -> t + + val mul : t -> t -> t +end + +let fixed_size_bit_vector n : (module FixedSizeBitVector) = + let open Bitlist in + let norm b = extract b 0 n in + let binop op x y = norm (op x y) in + (module struct + type t = Bitlist.t + + let shl a b = bvshl ~size:n a b + + let lshr a b = bvlshr ~size:n a b + + let mul = binop mul + end) + module IntSet : sig type t @@ -134,10 +158,16 @@ let bitlist sz = in let* (set_bits, clr_bits) = bitlist sz in let set_bits = - Bitlist.ones @@ Bitlist.exact sz set_bits Explanation.empty + Bitlist.extract ( + Bitlist.ones @@ + Bitlist.exact set_bits Explanation.empty + ) 0 sz in let clr_bits = - Bitlist.zeroes @@ Bitlist.exact sz (Z.lognot clr_bits) Explanation.empty + Bitlist.extract ( + Bitlist.zeroes @@ + Bitlist.exact (Z.extract (Z.lognot clr_bits) 0 sz) Explanation.empty + ) 0 sz in return @@ Bitlist.intersect set_bits clr_bits @@ -247,9 +277,7 @@ let test_bitlist_binop ~count sz zop bop = (Fmt.to_to_string Bitlist.pp)) Gen.(pair (bitlist sz) (bitlist sz)) (fun (s, t) -> - let u = bop s t in - Bitlist.width u = Bitlist.width s && - Bitlist.width u = Bitlist.width t && + let u = bop (fixed_size_bit_vector sz) s t in IntSet.subset (IntSet.map2 zop (of_bitlist s) (of_bitlist t)) (of_bitlist u)) @@ -271,7 +299,7 @@ let zmul sz a b = let test_bitlist_mul sz = test_bitlist_binop ~count:1_000 - sz (zmul sz) Bitlist.mul + sz (zmul sz) (fun (module BV) -> BV.mul) let () = Test.check_exn (test_bitlist_mul 3) @@ -290,7 +318,7 @@ let () = let test_bitlist_shl sz = test_bitlist_binop ~count:1_000 - sz (zshl sz) Bitlist.shl + sz (zshl sz) (fun (module BV) -> BV.shl) let () = Test.check_exn (test_bitlist_shl 3) @@ -309,7 +337,7 @@ let () = let test_bitlist_lshr sz = test_bitlist_binop ~count:1_000 - sz zlshr Bitlist.lshr + sz zlshr (fun (module BV) -> BV.lshr) let () = Test.check_exn (test_bitlist_lshr 3)