Skip to content

Commit

Permalink
Clean solver functors
Browse files Browse the repository at this point in the history
  • Loading branch information
filipeom committed Dec 24, 2023
1 parent 6f36aad commit 5beceda
Show file tree
Hide file tree
Showing 17 changed files with 126 additions and 151 deletions.
4 changes: 2 additions & 2 deletions bin/main.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Cmdliner
module Z3_batch = Batch.Make (Z3_mappings)
module Z3_incremental = Incremental.Make (Z3_mappings)
module Z3_batch = Solver.Batch (Z3_mappings)
module Z3_incremental = Solver.Incremental (Z3_mappings)
module Interpret = Interpret.Make (Z3_batch)

let get_contents = function
Expand Down
72 changes: 0 additions & 72 deletions lib/batch.ml

This file was deleted.

5 changes: 0 additions & 5 deletions lib/batch.mli

This file was deleted.

3 changes: 1 addition & 2 deletions lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@
(modules
ast
;axioms
batch
eval_numeric
expr
incremental
interpret
interpret_intf
lexer
Expand All @@ -25,6 +23,7 @@
parser
params
run
solver
solver_intf
symbol
ty
Expand Down
55 changes: 0 additions & 55 deletions lib/incremental.ml

This file was deleted.

5 changes: 0 additions & 5 deletions lib/incremental.mli

This file was deleted.

113 changes: 113 additions & 0 deletions lib/solver.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
exception Unknown

let ( let+ ) o f = Option.map f o

module Base (M : Mappings_intf.S) = struct
let solver_time = ref 0.0
let solver_count = ref 0

let time_call f acc =
let start = Stdlib.Sys.time () in
let ret = f () in
acc := !acc +. (Stdlib.Sys.time () -. start);
ret

let update_param_values params =
M.update_param_value Timeout (Params.get params Timeout);
M.update_param_value Model (Params.get params Model);
M.update_param_value Unsat_core (Params.get params Unsat_core);
M.update_param_value Ematching (Params.get params Ematching)

let interrupt () = M.interrupt ()
end

module Batch (Mappings : Mappings_intf.S) = struct
include Base (Mappings)

type solver = Mappings.solver

type t =
{ solver : solver
; mutable top : Expr.t list
; stack : Expr.t list Stack.t
}

let create ?params () =
Option.iter update_param_values params;
{ solver = Mappings.mk_solver (); top = []; stack = Stack.create () }

let clone ({ solver; top; stack } : t) : t =
{ solver; top; stack = Stack.copy stack }

let push ({ top; stack; _ } : t) : unit = Stack.push top stack

let pop (s : t) (lvl : int) : unit =
assert (lvl <= Stack.length s.stack);
for _ = 1 to lvl do
s.top <- Stack.pop s.stack
done

let reset (s : t) =
Mappings.reset s.solver;
Stack.clear s.stack;
s.top <- []

let add (s : t) (es : Expr.t list) : unit = s.top <- es @ s.top
let get_assertions (s : t) : Expr.t list = s.top [@@inline]

let check (s : t) (es : Expr.t list) : bool =
let es' = es @ s.top in
solver_count := !solver_count + 1;
let sat = time_call (fun () -> Mappings.check s.solver es') solver_time in
match Mappings.satisfiability sat with
| Mappings_intf.Satisfiable -> true
| Mappings_intf.Unsatisfiable -> false
| Mappings_intf.Unknown -> raise Unknown

let get_value (solver : t) (e : Expr.t) : Expr.t =
match Mappings.solver_model solver.solver with
| Some m -> Expr.(Val (Mappings.value m e) @: e.ty)
| None -> assert false

let model ?(symbols : Symbol.t list option) (s : t) : Model.t option =
let+ model = Mappings.solver_model s.solver in
Mappings.values_of_model ?symbols model
end

module Incremental (Mappings : Mappings_intf.S) = struct
include Base (Mappings)

type t = Mappings.solver
type solver = t

let create ?params () : t =
Option.iter update_param_values params;
Mappings.mk_solver ()

let clone (solver : t) : t = Mappings.translate solver
let push (solver : t) : unit = Mappings.push solver
let pop (solver : t) (lvl : int) : unit = Mappings.pop solver lvl
let reset (solver : t) : unit = Mappings.reset solver
let add (solver : t) (es : Expr.t list) : unit = Mappings.add_solver solver es
let get_assertions (_solver : t) : Expr.t list = assert false

let check (solver : t) (es : Expr.t list) : bool =
solver_count := !solver_count + 1;
let sat = time_call (fun () -> Mappings.check solver es) solver_time in
match Mappings.satisfiability sat with
| Mappings_intf.Satisfiable -> true
| Mappings_intf.Unknown -> raise Unknown
| Mappings_intf.Unsatisfiable -> false

let get_value (solver : t) (e : Expr.t) : Expr.t =
match Mappings.solver_model solver with
| Some m -> Expr.(Val (Mappings.value m e) @: e.ty)
| None -> assert false

let model ?(symbols : Symbol.t list option) (solver : t) : Model.t Option.t =
let+ model = Mappings.solver_model solver in
Mappings.values_of_model ?symbols model
end

module Batch' (M : Mappings_intf.S) : Solver_intf.S = Batch (M)
module Incremental' (M : Mappings_intf.S) : Solver_intf.S = Incremental (M)
2 changes: 1 addition & 1 deletion test/test_axiom.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
open Encoding
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()

Expand Down
2 changes: 1 addition & 1 deletion test/test_batch.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Ty
open Expr
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()
let symb_x = Symbol.("x" @: Ty_int)
Expand Down
2 changes: 1 addition & 1 deletion test/test_bool.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Ty
open Expr
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()
let symb_x = Symbol.("x" @: Ty_bool)
Expand Down
2 changes: 1 addition & 1 deletion test/test_f32.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Ty
open Expr
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()
let x = Expr.mk_symbol Symbol.("x" @: Ty_fp S32)
Expand Down
2 changes: 1 addition & 1 deletion test/test_int.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Ty
open Expr
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()
let one = Val (Int 1) @: Ty_int
Expand Down
2 changes: 1 addition & 1 deletion test/test_params.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
open Encoding
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let%test_unit _ =
let params = Params.(default () & (Ematching, false)) in
Expand Down
2 changes: 1 addition & 1 deletion test/test_parser.ml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
open Encoding
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)
module Interpret = Interpret.Make (Batch)

let parse script = Run.parse_string script
Expand Down
2 changes: 1 addition & 1 deletion test/test_real.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Ty
open Expr
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()

Expand Down
2 changes: 1 addition & 1 deletion test/test_regression.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Ty
open Expr
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()

Expand Down
2 changes: 1 addition & 1 deletion test/test_str.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Encoding
open Ty
open Expr
module Batch = Batch.Make (Z3_mappings)
module Batch = Solver.Batch (Z3_mappings)

let solver = Batch.create ()
let abc = Val (Str "abc") @: Ty_str
Expand Down

0 comments on commit 5beceda

Please sign in to comment.