diff --git a/src/bitwuzla_mappings.default.ml b/src/bitwuzla_mappings.default.ml index e36bb975..8ec027b2 100644 --- a/src/bitwuzla_mappings.default.ml +++ b/src/bitwuzla_mappings.default.ml @@ -56,6 +56,10 @@ module Fresh_bitwuzla (B : Bitwuzla_cxx.S) : M = struct let ite cond t1 t2 = mk_term3 Kind.Ite cond t1 t2 + let forall _ = Fmt.failwith "Bitwuzla_mappings: forall not implemented" + + let exists _ = Fmt.failwith "Bitwuzla_mappings: exists not implemented" + module Types = struct let int = Obj.magic 0xdeadc0de diff --git a/src/cvc5_mappings.default.ml b/src/cvc5_mappings.default.ml index 0135bb1c..68ffa2dc 100644 --- a/src/cvc5_mappings.default.ml +++ b/src/cvc5_mappings.default.ml @@ -54,6 +54,10 @@ module Fresh_cvc5 () = struct let ite cond t1 t2 = Term.mk_term tm Kind.Ite [| cond; t1; t2 |] + let forall _ = Fmt.failwith "Cvc5_mappings: forall not implemented" + + let exists _ = Fmt.failwith "Cvc5_mappings: exists not implemented" + module Types = struct let int = Sort.mk_int_sort tm diff --git a/src/expr.ml b/src/expr.ml index f302b5dc..62df92a5 100644 --- a/src/expr.ml +++ b/src/expr.ml @@ -96,11 +96,11 @@ end module Set = PatriciaTree.MakeHashconsedSet (Key) () -let make (e : expr) = Hc.hashcons e [@@inline] +let[@inline] make e = Hc.hashcons e -let view (hte : t) : expr = hte.node [@@inline] +let[@inline] view (hte : t) = hte.node -let compare (hte1 : t) (hte2 : t) = compare hte1.tag hte2.tag [@@inline] +let[@inline] compare (hte1 : t) (hte2 : t) = compare hte1.tag hte2.tag let symbol s = make (Symbol s) @@ -273,7 +273,13 @@ let ptr base offset = make (Ptr { base; offset }) let app symbol args = make (App (symbol, args)) -let let_in vars expr = make (Binder (Let_in, vars, expr)) +let[@inline] binder bt vars expr = make (Binder (bt, vars, expr)) + +let let_in vars body = binder Let_in vars body + +let forall vars body = binder Forall vars body + +let exists vars body = binder Exists vars body let unop' ty op hte = make (Unop (ty, op, hte)) [@@inline] diff --git a/src/expr.mli b/src/expr.mli index 52fe0d1d..7ce0d672 100644 --- a/src/expr.mli +++ b/src/expr.mli @@ -61,6 +61,10 @@ val app : Symbol.t -> t list -> t val let_in : t list -> t -> t +val forall : t list -> t -> t + +val exists : t list -> t -> t + (** Smart unop constructor, applies simplifications at constructor level *) val unop : Ty.t -> Ty.Unop.t -> t -> t diff --git a/src/mappings.ml b/src/mappings.ml index a52b012b..edc0b383 100644 --- a/src/mappings.ml +++ b/src/mappings.ml @@ -667,9 +667,18 @@ module Make (M_with_make : M_with_make) : S_with_fresh = struct let ctx, e1 = encode_expr ctx e1 in let ctx, e2 = encode_expr ctx e2 in (ctx, M.Bitv.concat e1 e2) - | List _ | App _ | Binder _ -> assert false - - let encode_exprs ctx (es : Expr.t list) : symbol_ctx * M.term list = + | Binder (Forall, vars, body) -> + let ctx, vars = encode_exprs ctx vars in + let ctx, body = encode_expr ctx body in + (ctx, M.forall vars body) + | Binder (Exists, vars, body) -> + let ctx, vars = encode_exprs ctx vars in + let ctx, body = encode_expr ctx body in + (ctx, M.exists vars body) + | List _ | App _ | Binder _ -> + Fmt.failwith "Cannot encode expression: %a" Expr.pp hte + + and encode_exprs ctx (es : Expr.t list) : symbol_ctx * M.term list = List.fold_left (fun (ctx, es) e -> let ctx, e = encode_expr ctx e in diff --git a/src/mappings.nop.ml b/src/mappings.nop.ml index 9b48199d..394ba19a 100644 --- a/src/mappings.nop.ml +++ b/src/mappings.nop.ml @@ -50,6 +50,10 @@ module Nop = struct let ite _ = assert false + let forall _ _ = assert false + + let exists _ _ = assert false + module Types = struct let int = () diff --git a/src/mappings_intf.ml b/src/mappings_intf.ml index 13c3daed..1ecd9dac 100644 --- a/src/mappings_intf.ml +++ b/src/mappings_intf.ml @@ -49,6 +49,10 @@ module type M = sig val ite : term -> term -> term -> term + val forall : term list -> term -> term + + val exists : term list -> term -> term + module Types : sig val int : ty diff --git a/src/rewrite.ml b/src/rewrite.ml index 8de83752..531e9845 100644 --- a/src/rewrite.ml +++ b/src/rewrite.ml @@ -101,7 +101,18 @@ let rec rewrite_expr (type_map, expr_map) hte = expr_map vars in rewrite_expr (type_map, expr_map) e - | Binder (_, _, _) -> assert false + | Binder (((Forall | Exists) as quantifier), vars, e) -> + let type_map, vars = + List.fold_left + (fun (map, vars) e -> + match Expr.view e with + | App (sym, [ e ]) -> + let ty = Expr.ty e in + (Symb_map.add sym ty map, Expr.symbol { sym with ty } :: vars) + | _ -> assert false ) + (type_map, []) vars + in + Expr.make (Binder (quantifier, vars, rewrite_expr (type_map, expr_map) e)) (** Acccumulates types of symbols in [type_map] and calls rewrite_expr *) let rewrite_cmd type_map cmd = diff --git a/src/smtlib.ml b/src/smtlib.ml index 7c2ef652..86dab0af 100644 --- a/src/smtlib.ml +++ b/src/smtlib.ml @@ -5,6 +5,7 @@ open Dolmen module Loc = Std.Loc +(* FIXME: Dangerous global reference *) let custom_sorts = Hashtbl.create 10 let pp_loc fmt = function @@ -296,11 +297,11 @@ module Term = struct (* Ids can only be symbols. Any other expr here is super wrong *) assert false - let letand ?loc:_ (vars : t list) (expr : t) : t = Expr.let_in vars expr + let letand ?loc:_ (vars : t list) (body : t) : t = Expr.let_in vars body - let forall ?loc:_ = assert false + let forall ?loc:_ (vars : t list) (body : t) : t = Expr.forall vars body - let exists ?loc:_ = assert false + let exists ?loc:_ (vars : t list) (body : t) : t = Expr.exists vars body let match_ ?loc:_ = assert false diff --git a/src/z3_mappings.default.ml b/src/z3_mappings.default.ml index 6777e094..d5de8f66 100644 --- a/src/z3_mappings.default.ml +++ b/src/z3_mappings.default.ml @@ -54,6 +54,14 @@ module M = struct let ite cond e1 e2 = Z3.Boolean.mk_ite ctx cond e1 e2 + let forall vars body = + Z3.Quantifier.mk_forall_const ctx vars body None [] [] None None + |> Z3.Quantifier.expr_of_quantifier + + let exists vars body = + Z3.Quantifier.mk_exists_const ctx vars body None [] [] None None + |> Z3.Quantifier.expr_of_quantifier + module Types = struct let int = Z3.Arithmetic.Integer.mk_sort ctx @@ -72,11 +80,11 @@ module M = struct let to_ety sort = match Z3.Sort.get_sort_kind sort with | Z3enums.INT_SORT -> Ty.Ty_int - | Z3enums.REAL_SORT -> Ty.Ty_real - | Z3enums.BOOL_SORT -> Ty.Ty_bool - | Z3enums.SEQ_SORT -> Ty.Ty_str - | Z3enums.BV_SORT -> Ty.Ty_bitv (Z3.BitVector.get_size sort) - | Z3enums.FLOATING_POINT_SORT -> + | REAL_SORT -> Ty.Ty_real + | BOOL_SORT -> Ty.Ty_bool + | SEQ_SORT -> Ty.Ty_str + | BV_SORT -> Ty.Ty_bitv (Z3.BitVector.get_size sort) + | FLOATING_POINT_SORT -> let ebits = Z3.FloatingPoint.get_ebits ctx sort in let sbits = Z3.FloatingPoint.get_sbits ctx sort in Ty.Ty_fp (ebits + sbits) @@ -91,8 +99,8 @@ module M = struct let to_bool interp = match Z3.Boolean.get_bool_value interp with | Z3enums.L_TRUE -> true - | Z3enums.L_FALSE -> false - | Z3enums.L_UNDEF -> + | L_FALSE -> false + | L_UNDEF -> Fmt.failwith "Z3_mappings2: to_bool: something went terribly wrong!" let to_string interp = Z3.Seq.get_string ctx interp diff --git a/test/smt2/dune b/test/smt2/dune index 87666283..fa446046 100644 --- a/test/smt2/dune +++ b/test/smt2/dune @@ -10,6 +10,8 @@ test_core_const.smt2 test_core_true.smt2 test_empty.smt2 + test_exists.smt2 + test_forall.smt2 test_fp.smt2 test_lia.smt2 test_lra.smt2 diff --git a/test/smt2/test_exists.smt2 b/test/smt2/test_exists.smt2 new file mode 100644 index 00000000..f3eef287 --- /dev/null +++ b/test/smt2/test_exists.smt2 @@ -0,0 +1,2 @@ +(assert (exists ((x Int)) (= (+ x 1) 2))) +(check-sat) diff --git a/test/smt2/test_forall.smt2 b/test/smt2/test_forall.smt2 new file mode 100644 index 00000000..0819f01d --- /dev/null +++ b/test/smt2/test_forall.smt2 @@ -0,0 +1,2 @@ +(assert (forall ((x Int)) (= x x))) +(check-sat) diff --git a/test/smt2/test_smt2.t b/test/smt2/test_smt2.t index 043dab01..90d86f55 100644 --- a/test/smt2/test_smt2.t +++ b/test/smt2/test_smt2.t @@ -82,3 +82,9 @@ Tests smt2 with the --from-file argument: (model (x str "abcd") (y str "a")) + +Test Forall and Exists parsing: + $ smtml run test_forall.smt2 + sat + $ smtml run test_exists.smt2 + sat