diff --git a/src/lib/reasoners/bitlist.ml b/src/lib/reasoners/bitlist.ml index 091537fdf..6b9adaa48 100644 --- a/src/lib/reasoners/bitlist.ml +++ b/src/lib/reasoners/bitlist.ml @@ -300,3 +300,59 @@ let mul a b = in concat (unknown (sz - width mid_bits - width low_bits) Ex.empty) @@ concat mid_bits low_bits + +let shl a b = + (* If the minimum value for [b] is larger than the bitwidth, the result is + zero. + + Otherwise, any low zero bit in [a] is also a zero bit in the result, and + the minimum value for [b] also accounts for that many minimum zeros (e.g. + ?000 shifted by at least 2 has at least 5 low zeroes). + + NB: we would like to use the lower bound from the interval domain for [b] + here. *) + match Z.to_int (increase_lower_bound b Z.zero) with + | n when n < width a -> + let low_zeros = Z.trailing_zeros @@ Z.lognot @@ a.bits_clr in + if low_zeros + n >= width a then + exact (width a) Z.zero (Ex.union (explanation a) (explanation b)) + else if low_zeros + n > 0 then + concat (unknown (width a - low_zeros - n) Ex.empty) @@ + exact (low_zeros + n) Z.zero (Ex.union (explanation a) (explanation b)) + else + unknown (width a) Ex.empty + | _ | exception Z.Overflow -> + exact (width a) Z.zero (explanation b) + +let lshr a b = + (* If the minimum value for [b] is larger than the bitwidth, the result is + zero. + + Otherwise, any high zero bit in [a] is also a zero bit in the result, and + the minimum value for [b] also accounts for that many minimum zeros (e.g. + 000??? shifted by at least 2 is 00000?). + + NB: we would like to use the lower bound from the interval domain for [b] + here. *) + match Z.to_int (increase_lower_bound b Z.zero) with + | n when n < width a -> + let sz = width a in + if Z.testbit a.bits_clr (sz - 1) then (* MSB is zero *) + let low_msb_zero = Z.numbits @@ Z.extract (Z.lognot a.bits_clr) 0 sz in + let nb_msb_zeros = sz - low_msb_zero in + assert (nb_msb_zeros > 0); + let nb_zeros = nb_msb_zeros + n in + if nb_zeros >= sz then + exact sz Z.zero (Ex.union (explanation a) (explanation b)) + else + concat + (exact nb_zeros Z.zero (Ex.union (explanation a) (explanation b))) + (unknown (sz - nb_zeros) Ex.empty) + else if n > 0 then + concat + (exact n Z.zero (explanation b)) + (unknown (sz - n) Ex.empty) + else + unknown sz Ex.empty + | _ | exception Z.Overflow -> + exact (width a) Z.zero (explanation b) diff --git a/src/lib/reasoners/bitlist.mli b/src/lib/reasoners/bitlist.mli index cc1e753f8..14a69c888 100644 --- a/src/lib/reasoners/bitlist.mli +++ b/src/lib/reasoners/bitlist.mli @@ -121,6 +121,12 @@ val logxor : t -> t -> t val mul : t -> t -> t (** Multiplication. *) +val shl : t -> t -> t +(** Logical left shift. *) + +val lshr : t -> t -> t +(** Logical right shift. *) + val concat : t -> t -> t (** Bit-vector concatenation. *) diff --git a/src/lib/reasoners/bitv.ml b/src/lib/reasoners/bitv.ml index 385269dfd..75ca4413f 100644 --- a/src/lib/reasoners/bitv.ml +++ b/src/lib/reasoners/bitv.ml @@ -350,7 +350,8 @@ module Shostak(X : ALIEN) = struct | Op ( Concat | Extract _ | BV2Nat | BVnot | BVand | BVor | BVxor - | BVadd | BVsub | BVmul | BVudiv | BVurem) + | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr) -> true | _ -> false @@ -409,6 +410,7 @@ module Shostak(X : ALIEN) = struct | { f = Op ( BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr ); _ } -> X.term_embed t, [] | _ -> X.make t diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index c89c6128d..fc743066c 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -256,6 +256,12 @@ module Constraint : sig This uses the convention that [x % 0] is [x]. *) + val bvshl : X.r -> X.r -> X.r -> t + (** [bvshl r x y] is the constraint [r = x << y] *) + + val bvlshr : X.r -> X.r -> X.r -> t + (** [bvshl r x y] is the constraint [r = x >> y] *) + val bvule : X.r -> X.r -> t val bvugt : X.r -> X.r -> t @@ -273,6 +279,8 @@ end = struct | Band | Bor | Bxor (* Arithmetic operations *) | Badd | Bmul | Budiv | Burem + (* Shift operations *) + | Bshl | Blshr let pp_binop ppf = function | Band -> Fmt.pf ppf "bvand" @@ -282,6 +290,8 @@ end = struct | Bmul -> Fmt.pf ppf "bvmul" | Budiv -> Fmt.pf ppf "bvudiv" | Burem -> Fmt.pf ppf "bvurem" + | Bshl -> Fmt.pf ppf "bvshl" + | Blshr -> Fmt.pf ppf "bvlshr" let equal_binop op1 op2 = match op1, op2 with @@ -304,12 +314,18 @@ end = struct | Budiv, _ | _, Budiv -> false | Burem, Burem -> true + | Burem, _ | _, Burem -> false + + | Bshl, Bshl -> true + | Bshl, _ | _, Bshl -> false + + | Blshr, Blshr -> true let hash_binop : binop -> int = Hashtbl.hash let is_commutative = function | Band | Bor | Bxor | Badd | Bmul -> true - | Budiv | Burem -> false + | Budiv | Burem | Bshl | Blshr -> false let propagate_binop ~ex dx op dy dz = let open Domains.Ephemeral in @@ -343,6 +359,12 @@ end = struct (* TODO: full adder propagation *) () + | Bshl -> (* Only forward propagation for now *) + update ~ex dx (Bitlist.shl !!dy !!dz) + + | Blshr -> (* Only forward propagation for now *) + update ~ex dx (Bitlist.lshr !!dy !!dz) + | Bmul -> (* Only forward propagation for now *) update ~ex dx (Bitlist.mul !!dy !!dz) @@ -361,6 +383,12 @@ end = struct update ~ex dy @@ norm @@ Intervals.Int.sub !!dr !!dx; update ~ex dx @@ norm @@ Intervals.Int.sub !!dr !!dy + | Bshl -> (* Only forward propagation for now *) + update ~ex dr @@ Intervals.Int.bvshl ~size:sz !!dx !!dy + + | Blshr -> (* Only forward propagation for now *) + update ~ex dr @@ Intervals.Int.lshr !!dx !!dy + | Bmul -> (* Only forward propagation for now *) update ~ex dr @@ norm @@ Intervals.Int.mul !!dx !!dy @@ -574,6 +602,8 @@ end = struct let bvmul = cbinop Bmul let bvudiv = cbinop Budiv let bvurem = cbinop Burem + let bvshl = cbinop Bshl + let bvlshr = cbinop Blshr let crel r = hcons @@ Crel r @@ -729,6 +759,27 @@ end = struct ) else false + (* Add the constraint: r = x >> c *) + let add_lshr_const acts r x c = + let sz = bitwidth r in + match Z.to_int c with + | 0 -> add_eq acts r x + | n when n < sz -> + assert (n > 0); + let r_bitv = Shostak.Bitv.embed r in + let low_bits = + Shostak.Bitv.is_mine @@ + Bitv.extract sz n (sz - 1) (Shostak.Bitv.embed x) + in + add_eq acts + (Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) r_bitv) + low_bits; + add_eq_const acts + (Shostak.Bitv.is_mine @@ Bitv.extract sz (sz - n) (sz - 1) r_bitv) + Z.zero + | _ | exception Z.Overflow -> + add_eq_const acts r Z.zero + (* Ground evaluation rules for binary operators. *) let eval_binop op ty x y = match op with @@ -747,6 +798,18 @@ end = struct cast ty x else cast ty @@ Z.rem x y + | Bshl -> ( + match ty, Z.to_int y with + | Tbitv sz, y when y < sz -> + cast ty @@ Z.shift_left x y + | _ | exception Z.Overflow -> cast ty Z.zero + ) + | Blshr -> ( + match ty, Z.to_int y with + | Tbitv sz, y when y < sz -> + cast ty @@ Z.shift_right x y + | _ | exception Z.Overflow -> cast ty Z.zero + ) (* Constant simplification rules for binary operators. @@ -793,6 +856,17 @@ end = struct | Budiv | Burem -> false + (* shifts becomes a simple extraction when we know the right-hand side *) + | Bshl when X.is_constant y -> + add_shl_const acts r x (value y); + true + | Bshl -> false + + | Blshr when X.is_constant y -> + add_lshr_const acts r x (value y); + true + | Blshr -> false + (* Algebraic rewrite rules for binary operators. Rules based on constant simplifications are in [rw_binop_const]. *) @@ -864,6 +938,8 @@ let extract_binop = | BVmul -> Some bvmul | BVudiv -> Some bvudiv | BVurem -> Some bvurem + | BVshl -> Some bvshl + | BVlshr -> Some bvlshr | _ -> None let extract_constraints bcs uf r t = diff --git a/src/lib/reasoners/intervals.ml b/src/lib/reasoners/intervals.ml index 6abdd7faa..64a07753f 100644 --- a/src/lib/reasoners/intervals.ml +++ b/src/lib/reasoners/intervals.ml @@ -314,6 +314,47 @@ module ZEuclideanType = struct | Neg_infinite -> Pos_infinite | Pos_infinite -> Neg_infinite | Finite n -> Finite (Z.lognot n) + + (* Values larger than [max_int] are treated as +oo *) + let shift_left ?(max_int = max_int) x y = + match y with + | Neg_infinite -> + Fmt.invalid_arg "shl: must shift by nonnegative amount" + | Pos_infinite -> Pos_infinite + | Finite y when Z.sign y < 0 -> + Fmt.invalid_arg "shl: must shift by nonnegative amount" + | Finite y -> + match Z.to_int y with + | exception Z.Overflow -> Pos_infinite + | y -> + if y <= max_int then + match x with + | Neg_infinite -> Neg_infinite + | Pos_infinite -> Pos_infinite + | Finite x -> Finite (Z.shift_left x y) + else Pos_infinite + + let shift_right x y = + match y with + | Neg_infinite -> + invalid_arg "shift_right: must shift by nonnegative amount" + | Finite y when Z.sign y < 0 -> + invalid_arg "shift_right: must shift by nonnegative amount" + | Pos_infinite -> ( + match x with + | Pos_infinite -> invalid_arg "shift_right: undefined limit" + | _ -> zero + ) + | Finite y -> + match x with + | Neg_infinite -> Neg_infinite + | Pos_infinite -> Pos_infinite + | Finite x -> + match Z.to_int y with + | exception Z.Overflow -> + (* y > max_int -> x >> y = 0 since numbits x <= max_int *) + zero + | y -> Finite (Z.shift_right x y) end (* AlgebraicType interface for reals @@ -663,6 +704,31 @@ module Int = struct { lb = ZEuclideanType.zero ; ub = ZEuclideanType.pred i2.ub } ) u1 ) u2 + + let bvshl ~size u1 u2 = + assert (size > 0); + (* Values higher than [max_int] ultimately map to [0] *) + let max_int = size - 1 in + let zero_i = { lb = ZEuclideanType.zero ; ub = ZEuclideanType.zero } in + extract ~ofs:0 ~len:size @@ + of_set_nonempty @@ + map_to_set (fun i2 -> + assert (ZEuclideanType.sign i2.lb >= 0); + if ZEuclideanType.(compare i2.lb (finite @@ Z.of_int max_int)) > 0 then + (* if i2.lb > max_int, the result is always zero + must not call ZEuclideanType.shift_left or we will likely OOM *) + interval_set zero_i + else + (* equivalent to multiplication by a positive constant *) + approx_map_inc_to_set + (fun lb -> ZEuclideanType.shift_left lb i2.lb) + (fun ub -> ZEuclideanType.shift_left ~max_int ub i2.ub) + u1 + ) u2 + + let lshr u1 u2 = + of_set_nonempty @@ + map2_mon_to_set ZEuclideanType.shift_right Inc u1 Dec u2 end module Legacy = struct diff --git a/src/lib/reasoners/intervals.mli b/src/lib/reasoners/intervals.mli index ee505f791..5cf501ecf 100644 --- a/src/lib/reasoners/intervals.mli +++ b/src/lib/reasoners/intervals.mli @@ -88,6 +88,19 @@ module Int : sig theory, i.e. where [bvurem n 0] is [n]. [s] and [t] must be within the [0, 2^sz - 1] range. *) + + val bvshl : size:int -> t -> t -> t + (** [shl sz s t] computes an overapproximation of the left shift [s lsl t], + truncating the result to [sz] bits. + + [s] and [t] must only contain non-negative integers. *) + + val lshr : t -> t -> t + (** [lshr s t] computes an approximation of the logical right shift [s lsr t]. + + Note that the result of logical right shift is independent of bit width. + + [s] and [t] must only contain non-negative integers. *) end module Legacy : sig diff --git a/src/lib/structures/expr.ml b/src/lib/structures/expr.ml index 7c5bf335c..7e703e226 100644 --- a/src/lib/structures/expr.ml +++ b/src/lib/structures/expr.ml @@ -3156,8 +3156,8 @@ module BV = struct (bvneg u) (* Shift operations *) - let bvshl s t = int2bv (size2 s t) Ints.(bv2nat s * (~$2 ** bv2nat t)) - let bvlshr s t = int2bv (size2 s t) Ints.(bv2nat s / (~$2 ** bv2nat t)) + let bvshl s t = mk_term (Op BVshl) [s; t] (type_info s) + let bvlshr s t = mk_term (Op BVlshr) [s; t] (type_info s) let bvashr s t = let m = size2 s t in ite (is (extract (m - 1) (m - 1) s) 0) diff --git a/src/lib/structures/symbols.ml b/src/lib/structures/symbols.ml index 4eb4e1c3c..b92547ad7 100644 --- a/src/lib/structures/symbols.ml +++ b/src/lib/structures/symbols.ml @@ -45,6 +45,7 @@ type operator = | Extract of int * int (* lower bound * upper bound *) | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV of int | BV2Nat (* FP *) | Float @@ -194,6 +195,7 @@ let compare_operators op1 op2 = | Integer_log2 | Pow | Integer_round | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV _ | BV2Nat | Not_theory_constant | Is_theory_constant | Linear_dependency | Constr _ | Destruct _ | Tite) -> assert false @@ -358,6 +360,8 @@ module AEPrinter = struct | BVmul -> Fmt.pf ppf "bvmul" | BVudiv -> Fmt.pf ppf "bvudiv" | BVurem -> Fmt.pf ppf "bvurem" + | BVshl -> Fmt.pf ppf "bvshl" + | BVlshr -> Fmt.pf ppf "bvlshr" (* ArraysEx theory *) | Get -> Fmt.pf ppf "get" @@ -464,6 +468,8 @@ module SmtPrinter = struct | BVmul -> Fmt.pf ppf "bvmul" | BVudiv -> Fmt.pf ppf "bvudiv" | BVurem -> Fmt.pf ppf "bvurem" + | BVshl -> Fmt.pf ppf "bvshl" + | BVlshr -> Fmt.pf ppf "bvlshr" (* ArraysEx theory *) | Get -> Fmt.pf ppf "select" diff --git a/src/lib/structures/symbols.mli b/src/lib/structures/symbols.mli index 5a7a029a6..82122c1c2 100644 --- a/src/lib/structures/symbols.mli +++ b/src/lib/structures/symbols.mli @@ -45,6 +45,7 @@ type operator = | Extract of int * int (* lower bound * upper bound *) | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV of int | BV2Nat (* FP *) | Float diff --git a/tests/bitvec_tests.ml b/tests/bitvec_tests.ml index 6a7b865d3..723cd3700 100644 --- a/tests/bitvec_tests.ml +++ b/tests/bitvec_tests.ml @@ -276,6 +276,44 @@ let test_bitlist_mul sz = let () = Test.check_exn (test_bitlist_mul 3) +let zshl sz a b = + match Z.to_int b with + | b when b < sz -> Z.extract (Z.shift_left a b) 0 sz + | _ | exception Z.Overflow -> Z.zero + +let test_interval_shl sz = + test_interval_binop ~count:1_000 + sz (zshl sz) (Intervals.Int.bvshl ~size:sz) + +let () = + Test.check_exn (test_interval_shl 3) + +let test_bitlist_shl sz = + test_bitlist_binop ~count:1_000 + sz (zshl sz) Bitlist.shl + +let () = + Test.check_exn (test_bitlist_shl 3) + +let zlshr a b = + match Z.to_int b with + | b -> Z.shift_right a b + | exception Z.Overflow -> Z.zero + +let test_interval_lshr sz = + test_interval_binop ~count:1_000 + sz zlshr Intervals.Int.lshr + +let () = + Test.check_exn (test_interval_lshr 3) + +let test_bitlist_lshr sz = + test_bitlist_binop ~count:1_000 + sz zlshr Bitlist.lshr + +let () = + Test.check_exn (test_bitlist_lshr 3) + let zudiv sz a b = if Z.equal b Z.zero then Z.extract Z.minus_one 0 sz