From 3a38c07b196862efcfa742f8a88dc06cf55f92e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Basile=20Cl=C3=A9ment?= Date: Fri, 29 Mar 2024 16:08:38 +0100 Subject: [PATCH] feat(BV, CP): Add propagators for bvshl and bvlshr This patch adds interval and bitlist propagators for the bvshl (left shift) and bvlshr (logical right shift) in the intervals and bitlist domains for bit-vectors. The interval propagator for left shift needs to be written specially in order to properly deal with overflow, but the propagator for bvlshr is written using a generic propagator for (bi)-monotone functions. The bitlist propagator for bvshl is required because it needs to propagate information regarding low bits that are not tracked by intervals. However, I am not sure that the bitlist propagator for bvlshr is actually needed since it might be subsumed by the interval propagator for bvlshr (and consistency constraints) entirely, and we might want to remove it. --- src/lib/reasoners/bitlist.ml | 56 +++++++++++++++++++++++ src/lib/reasoners/bitlist.mli | 6 +++ src/lib/reasoners/bitv.ml | 4 +- src/lib/reasoners/bitv_rel.ml | 78 ++++++++++++++++++++++++++++++++- src/lib/reasoners/intervals.ml | 66 ++++++++++++++++++++++++++++ src/lib/reasoners/intervals.mli | 13 ++++++ src/lib/structures/expr.ml | 4 +- src/lib/structures/symbols.ml | 6 +++ src/lib/structures/symbols.mli | 1 + tests/bitvec_tests.ml | 38 ++++++++++++++++ 10 files changed, 268 insertions(+), 4 deletions(-) diff --git a/src/lib/reasoners/bitlist.ml b/src/lib/reasoners/bitlist.ml index 091537fdf..6b9adaa48 100644 --- a/src/lib/reasoners/bitlist.ml +++ b/src/lib/reasoners/bitlist.ml @@ -300,3 +300,59 @@ let mul a b = in concat (unknown (sz - width mid_bits - width low_bits) Ex.empty) @@ concat mid_bits low_bits + +let shl a b = + (* If the minimum value for [b] is larger than the bitwidth, the result is + zero. + + Otherwise, any low zero bit in [a] is also a zero bit in the result, and + the minimum value for [b] also accounts for that many minimum zeros (e.g. + ?000 shifted by at least 2 has at least 5 low zeroes). + + NB: we would like to use the lower bound from the interval domain for [b] + here. *) + match Z.to_int (increase_lower_bound b Z.zero) with + | n when n < width a -> + let low_zeros = Z.trailing_zeros @@ Z.lognot @@ a.bits_clr in + if low_zeros + n >= width a then + exact (width a) Z.zero (Ex.union (explanation a) (explanation b)) + else if low_zeros + n > 0 then + concat (unknown (width a - low_zeros - n) Ex.empty) @@ + exact (low_zeros + n) Z.zero (Ex.union (explanation a) (explanation b)) + else + unknown (width a) Ex.empty + | _ | exception Z.Overflow -> + exact (width a) Z.zero (explanation b) + +let lshr a b = + (* If the minimum value for [b] is larger than the bitwidth, the result is + zero. + + Otherwise, any high zero bit in [a] is also a zero bit in the result, and + the minimum value for [b] also accounts for that many minimum zeros (e.g. + 000??? shifted by at least 2 is 00000?). + + NB: we would like to use the lower bound from the interval domain for [b] + here. *) + match Z.to_int (increase_lower_bound b Z.zero) with + | n when n < width a -> + let sz = width a in + if Z.testbit a.bits_clr (sz - 1) then (* MSB is zero *) + let low_msb_zero = Z.numbits @@ Z.extract (Z.lognot a.bits_clr) 0 sz in + let nb_msb_zeros = sz - low_msb_zero in + assert (nb_msb_zeros > 0); + let nb_zeros = nb_msb_zeros + n in + if nb_zeros >= sz then + exact sz Z.zero (Ex.union (explanation a) (explanation b)) + else + concat + (exact nb_zeros Z.zero (Ex.union (explanation a) (explanation b))) + (unknown (sz - nb_zeros) Ex.empty) + else if n > 0 then + concat + (exact n Z.zero (explanation b)) + (unknown (sz - n) Ex.empty) + else + unknown sz Ex.empty + | _ | exception Z.Overflow -> + exact (width a) Z.zero (explanation b) diff --git a/src/lib/reasoners/bitlist.mli b/src/lib/reasoners/bitlist.mli index cc1e753f8..14a69c888 100644 --- a/src/lib/reasoners/bitlist.mli +++ b/src/lib/reasoners/bitlist.mli @@ -121,6 +121,12 @@ val logxor : t -> t -> t val mul : t -> t -> t (** Multiplication. *) +val shl : t -> t -> t +(** Logical left shift. *) + +val lshr : t -> t -> t +(** Logical right shift. *) + val concat : t -> t -> t (** Bit-vector concatenation. *) diff --git a/src/lib/reasoners/bitv.ml b/src/lib/reasoners/bitv.ml index 385269dfd..75ca4413f 100644 --- a/src/lib/reasoners/bitv.ml +++ b/src/lib/reasoners/bitv.ml @@ -350,7 +350,8 @@ module Shostak(X : ALIEN) = struct | Op ( Concat | Extract _ | BV2Nat | BVnot | BVand | BVor | BVxor - | BVadd | BVsub | BVmul | BVudiv | BVurem) + | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr) -> true | _ -> false @@ -409,6 +410,7 @@ module Shostak(X : ALIEN) = struct | { f = Op ( BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr ); _ } -> X.term_embed t, [] | _ -> X.make t diff --git a/src/lib/reasoners/bitv_rel.ml b/src/lib/reasoners/bitv_rel.ml index c89c6128d..fc743066c 100644 --- a/src/lib/reasoners/bitv_rel.ml +++ b/src/lib/reasoners/bitv_rel.ml @@ -256,6 +256,12 @@ module Constraint : sig This uses the convention that [x % 0] is [x]. *) + val bvshl : X.r -> X.r -> X.r -> t + (** [bvshl r x y] is the constraint [r = x << y] *) + + val bvlshr : X.r -> X.r -> X.r -> t + (** [bvshl r x y] is the constraint [r = x >> y] *) + val bvule : X.r -> X.r -> t val bvugt : X.r -> X.r -> t @@ -273,6 +279,8 @@ end = struct | Band | Bor | Bxor (* Arithmetic operations *) | Badd | Bmul | Budiv | Burem + (* Shift operations *) + | Bshl | Blshr let pp_binop ppf = function | Band -> Fmt.pf ppf "bvand" @@ -282,6 +290,8 @@ end = struct | Bmul -> Fmt.pf ppf "bvmul" | Budiv -> Fmt.pf ppf "bvudiv" | Burem -> Fmt.pf ppf "bvurem" + | Bshl -> Fmt.pf ppf "bvshl" + | Blshr -> Fmt.pf ppf "bvlshr" let equal_binop op1 op2 = match op1, op2 with @@ -304,12 +314,18 @@ end = struct | Budiv, _ | _, Budiv -> false | Burem, Burem -> true + | Burem, _ | _, Burem -> false + + | Bshl, Bshl -> true + | Bshl, _ | _, Bshl -> false + + | Blshr, Blshr -> true let hash_binop : binop -> int = Hashtbl.hash let is_commutative = function | Band | Bor | Bxor | Badd | Bmul -> true - | Budiv | Burem -> false + | Budiv | Burem | Bshl | Blshr -> false let propagate_binop ~ex dx op dy dz = let open Domains.Ephemeral in @@ -343,6 +359,12 @@ end = struct (* TODO: full adder propagation *) () + | Bshl -> (* Only forward propagation for now *) + update ~ex dx (Bitlist.shl !!dy !!dz) + + | Blshr -> (* Only forward propagation for now *) + update ~ex dx (Bitlist.lshr !!dy !!dz) + | Bmul -> (* Only forward propagation for now *) update ~ex dx (Bitlist.mul !!dy !!dz) @@ -361,6 +383,12 @@ end = struct update ~ex dy @@ norm @@ Intervals.Int.sub !!dr !!dx; update ~ex dx @@ norm @@ Intervals.Int.sub !!dr !!dy + | Bshl -> (* Only forward propagation for now *) + update ~ex dr @@ Intervals.Int.bvshl ~size:sz !!dx !!dy + + | Blshr -> (* Only forward propagation for now *) + update ~ex dr @@ Intervals.Int.lshr !!dx !!dy + | Bmul -> (* Only forward propagation for now *) update ~ex dr @@ norm @@ Intervals.Int.mul !!dx !!dy @@ -574,6 +602,8 @@ end = struct let bvmul = cbinop Bmul let bvudiv = cbinop Budiv let bvurem = cbinop Burem + let bvshl = cbinop Bshl + let bvlshr = cbinop Blshr let crel r = hcons @@ Crel r @@ -729,6 +759,27 @@ end = struct ) else false + (* Add the constraint: r = x >> c *) + let add_lshr_const acts r x c = + let sz = bitwidth r in + match Z.to_int c with + | 0 -> add_eq acts r x + | n when n < sz -> + assert (n > 0); + let r_bitv = Shostak.Bitv.embed r in + let low_bits = + Shostak.Bitv.is_mine @@ + Bitv.extract sz n (sz - 1) (Shostak.Bitv.embed x) + in + add_eq acts + (Shostak.Bitv.is_mine @@ Bitv.extract sz 0 (sz - 1 - n) r_bitv) + low_bits; + add_eq_const acts + (Shostak.Bitv.is_mine @@ Bitv.extract sz (sz - n) (sz - 1) r_bitv) + Z.zero + | _ | exception Z.Overflow -> + add_eq_const acts r Z.zero + (* Ground evaluation rules for binary operators. *) let eval_binop op ty x y = match op with @@ -747,6 +798,18 @@ end = struct cast ty x else cast ty @@ Z.rem x y + | Bshl -> ( + match ty, Z.to_int y with + | Tbitv sz, y when y < sz -> + cast ty @@ Z.shift_left x y + | _ | exception Z.Overflow -> cast ty Z.zero + ) + | Blshr -> ( + match ty, Z.to_int y with + | Tbitv sz, y when y < sz -> + cast ty @@ Z.shift_right x y + | _ | exception Z.Overflow -> cast ty Z.zero + ) (* Constant simplification rules for binary operators. @@ -793,6 +856,17 @@ end = struct | Budiv | Burem -> false + (* shifts becomes a simple extraction when we know the right-hand side *) + | Bshl when X.is_constant y -> + add_shl_const acts r x (value y); + true + | Bshl -> false + + | Blshr when X.is_constant y -> + add_lshr_const acts r x (value y); + true + | Blshr -> false + (* Algebraic rewrite rules for binary operators. Rules based on constant simplifications are in [rw_binop_const]. *) @@ -864,6 +938,8 @@ let extract_binop = | BVmul -> Some bvmul | BVudiv -> Some bvudiv | BVurem -> Some bvurem + | BVshl -> Some bvshl + | BVlshr -> Some bvlshr | _ -> None let extract_constraints bcs uf r t = diff --git a/src/lib/reasoners/intervals.ml b/src/lib/reasoners/intervals.ml index 6abdd7faa..64a07753f 100644 --- a/src/lib/reasoners/intervals.ml +++ b/src/lib/reasoners/intervals.ml @@ -314,6 +314,47 @@ module ZEuclideanType = struct | Neg_infinite -> Pos_infinite | Pos_infinite -> Neg_infinite | Finite n -> Finite (Z.lognot n) + + (* Values larger than [max_int] are treated as +oo *) + let shift_left ?(max_int = max_int) x y = + match y with + | Neg_infinite -> + Fmt.invalid_arg "shl: must shift by nonnegative amount" + | Pos_infinite -> Pos_infinite + | Finite y when Z.sign y < 0 -> + Fmt.invalid_arg "shl: must shift by nonnegative amount" + | Finite y -> + match Z.to_int y with + | exception Z.Overflow -> Pos_infinite + | y -> + if y <= max_int then + match x with + | Neg_infinite -> Neg_infinite + | Pos_infinite -> Pos_infinite + | Finite x -> Finite (Z.shift_left x y) + else Pos_infinite + + let shift_right x y = + match y with + | Neg_infinite -> + invalid_arg "shift_right: must shift by nonnegative amount" + | Finite y when Z.sign y < 0 -> + invalid_arg "shift_right: must shift by nonnegative amount" + | Pos_infinite -> ( + match x with + | Pos_infinite -> invalid_arg "shift_right: undefined limit" + | _ -> zero + ) + | Finite y -> + match x with + | Neg_infinite -> Neg_infinite + | Pos_infinite -> Pos_infinite + | Finite x -> + match Z.to_int y with + | exception Z.Overflow -> + (* y > max_int -> x >> y = 0 since numbits x <= max_int *) + zero + | y -> Finite (Z.shift_right x y) end (* AlgebraicType interface for reals @@ -663,6 +704,31 @@ module Int = struct { lb = ZEuclideanType.zero ; ub = ZEuclideanType.pred i2.ub } ) u1 ) u2 + + let bvshl ~size u1 u2 = + assert (size > 0); + (* Values higher than [max_int] ultimately map to [0] *) + let max_int = size - 1 in + let zero_i = { lb = ZEuclideanType.zero ; ub = ZEuclideanType.zero } in + extract ~ofs:0 ~len:size @@ + of_set_nonempty @@ + map_to_set (fun i2 -> + assert (ZEuclideanType.sign i2.lb >= 0); + if ZEuclideanType.(compare i2.lb (finite @@ Z.of_int max_int)) > 0 then + (* if i2.lb > max_int, the result is always zero + must not call ZEuclideanType.shift_left or we will likely OOM *) + interval_set zero_i + else + (* equivalent to multiplication by a positive constant *) + approx_map_inc_to_set + (fun lb -> ZEuclideanType.shift_left lb i2.lb) + (fun ub -> ZEuclideanType.shift_left ~max_int ub i2.ub) + u1 + ) u2 + + let lshr u1 u2 = + of_set_nonempty @@ + map2_mon_to_set ZEuclideanType.shift_right Inc u1 Dec u2 end module Legacy = struct diff --git a/src/lib/reasoners/intervals.mli b/src/lib/reasoners/intervals.mli index ee505f791..5cf501ecf 100644 --- a/src/lib/reasoners/intervals.mli +++ b/src/lib/reasoners/intervals.mli @@ -88,6 +88,19 @@ module Int : sig theory, i.e. where [bvurem n 0] is [n]. [s] and [t] must be within the [0, 2^sz - 1] range. *) + + val bvshl : size:int -> t -> t -> t + (** [shl sz s t] computes an overapproximation of the left shift [s lsl t], + truncating the result to [sz] bits. + + [s] and [t] must only contain non-negative integers. *) + + val lshr : t -> t -> t + (** [lshr s t] computes an approximation of the logical right shift [s lsr t]. + + Note that the result of logical right shift is independent of bit width. + + [s] and [t] must only contain non-negative integers. *) end module Legacy : sig diff --git a/src/lib/structures/expr.ml b/src/lib/structures/expr.ml index 26d776288..0e059a0c3 100644 --- a/src/lib/structures/expr.ml +++ b/src/lib/structures/expr.ml @@ -3192,8 +3192,8 @@ module BV = struct (bvneg u) (* Shift operations *) - let bvshl s t = int2bv (size2 s t) Ints.(bv2nat s * (~$2 ** bv2nat t)) - let bvlshr s t = int2bv (size2 s t) Ints.(bv2nat s / (~$2 ** bv2nat t)) + let bvshl s t = mk_term (Op BVshl) [s; t] (type_info s) + let bvlshr s t = mk_term (Op BVlshr) [s; t] (type_info s) let bvashr s t = let m = size2 s t in ite (is (extract (m - 1) (m - 1) s) 0) diff --git a/src/lib/structures/symbols.ml b/src/lib/structures/symbols.ml index 420a9259d..20200eacd 100644 --- a/src/lib/structures/symbols.ml +++ b/src/lib/structures/symbols.ml @@ -45,6 +45,7 @@ type operator = | Extract of int * int (* lower bound * upper bound *) | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV of int | BV2Nat (* FP *) | Float @@ -194,6 +195,7 @@ let compare_operators op1 op2 = | Integer_log2 | Pow | Integer_round | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV _ | BV2Nat | Not_theory_constant | Is_theory_constant | Linear_dependency | Constr _ | Destruct _ | Tite) -> assert false @@ -358,6 +360,8 @@ module AEPrinter = struct | BVmul -> Fmt.pf ppf "bvmul" | BVudiv -> Fmt.pf ppf "bvudiv" | BVurem -> Fmt.pf ppf "bvurem" + | BVshl -> Fmt.pf ppf "bvshl" + | BVlshr -> Fmt.pf ppf "bvlshr" (* ArraysEx theory *) | Get -> Fmt.pf ppf "get" @@ -464,6 +468,8 @@ module SmtPrinter = struct | BVmul -> Fmt.pf ppf "bvmul" | BVudiv -> Fmt.pf ppf "bvudiv" | BVurem -> Fmt.pf ppf "bvurem" + | BVshl -> Fmt.pf ppf "bvshl" + | BVlshr -> Fmt.pf ppf "bvlshr" (* ArraysEx theory *) | Get -> Fmt.pf ppf "select" diff --git a/src/lib/structures/symbols.mli b/src/lib/structures/symbols.mli index 5a7a029a6..82122c1c2 100644 --- a/src/lib/structures/symbols.mli +++ b/src/lib/structures/symbols.mli @@ -45,6 +45,7 @@ type operator = | Extract of int * int (* lower bound * upper bound *) | BVnot | BVand | BVor | BVxor | BVadd | BVsub | BVmul | BVudiv | BVurem + | BVshl | BVlshr | Int2BV of int | BV2Nat (* FP *) | Float diff --git a/tests/bitvec_tests.ml b/tests/bitvec_tests.ml index 6a7b865d3..723cd3700 100644 --- a/tests/bitvec_tests.ml +++ b/tests/bitvec_tests.ml @@ -276,6 +276,44 @@ let test_bitlist_mul sz = let () = Test.check_exn (test_bitlist_mul 3) +let zshl sz a b = + match Z.to_int b with + | b when b < sz -> Z.extract (Z.shift_left a b) 0 sz + | _ | exception Z.Overflow -> Z.zero + +let test_interval_shl sz = + test_interval_binop ~count:1_000 + sz (zshl sz) (Intervals.Int.bvshl ~size:sz) + +let () = + Test.check_exn (test_interval_shl 3) + +let test_bitlist_shl sz = + test_bitlist_binop ~count:1_000 + sz (zshl sz) Bitlist.shl + +let () = + Test.check_exn (test_bitlist_shl 3) + +let zlshr a b = + match Z.to_int b with + | b -> Z.shift_right a b + | exception Z.Overflow -> Z.zero + +let test_interval_lshr sz = + test_interval_binop ~count:1_000 + sz zlshr Intervals.Int.lshr + +let () = + Test.check_exn (test_interval_lshr 3) + +let test_bitlist_lshr sz = + test_bitlist_binop ~count:1_000 + sz zlshr Bitlist.lshr + +let () = + Test.check_exn (test_bitlist_lshr 3) + let zudiv sz a b = if Z.equal b Z.zero then Z.extract Z.minus_one 0 sz