Skip to content

Commit

Permalink
Make BinRel less abstract
Browse files Browse the repository at this point in the history
  • Loading branch information
bclement-ocp committed Aug 2, 2024
1 parent a5f6482 commit 2462f71
Showing 1 changed file with 51 additions and 98 deletions.
149 changes: 51 additions & 98 deletions src/lib/reasoners/rel_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -531,45 +531,34 @@ struct
end


module BinRel(X : OrderedType)(W : OrderedType) : sig
(** This module provides a thin abstraction to keep track of binary relations
between values of two different types. *)
module WatchMap(X : OrderedType)(W : OrderedType) : sig
(** This module provides a thin abstraction to keep track of
"watchers" associated to a given variable.
It allows finding all the watches associated to a variable, and
conversely of all the variable associated with a watch. *)

type t
(** The type of binary relations between [X.t] and [W.t]. *)
(** The type of maps from variables [X.t] to watches [W.t]. *)

val empty : t
(** The empty relation. *)

val add : X.t -> W.t -> t -> t
(** [add x w r] adds the tuple [(x, w)] to the relation. *)

val add_many : X.t -> W.Set.t -> t -> t
(** [add_many x ws t] adds the tuples [(x, w)] for each [w] in [ws]. *)

val range : X.t -> t -> W.Set.t
(** [range x t] returns all the watches [w] such that [(x, w)] is in the
relation. *)

val remove_dom : X.t -> t -> t
(** [remove_dom x r] removes all tuples of the form [(x, _)] from the
relation. *)

val remove_range : W.t -> t -> t
(** [remove_range w r] removes all tuples of the form [(_, w)] from the
relation. *)
val add_watches : X.t -> W.Set.t -> t -> t
(** [add_watches x ws t] associates all of the watches in [ws] to the
variable [x]. *)

val transfer_dom : X.t -> X.t -> t -> t
(** [transfer_dom x x' r] replaces all tuples of the form [(x, w)] in the
relation with the corresponding [(x', w)] tuple. *)
val watches : X.t -> t -> W.Set.t
(** [watches x t] returns all the watches associated to [x]. *)

val iter_range : X.t -> (W.t -> unit) -> t -> unit
(** [iter_range x f r] calls [f] on all the [w] such that [(x, w)] is in the
relation. *)
val take_watches : X.t -> t -> W.Set.t * t
(** [take_watches x t] returns a pair [ws, t'] where [ws] is the set of
watches associated with [x] in [t], and [t'] is identical to [t]
except that no watches are associated to [x]. *)

val fold_range : X.t -> (W.t -> 'a -> 'a) -> t -> 'a -> 'a
(** [fold_range x f r acc] folds [f] over all the [w] such that [(x, w)] is in
the relation.*)
val remove_watch : W.t -> t -> t
(** [remove_watch w t] removes [w] from [t] so that it is no longer
associated to any variable. *)
end = struct
module MX = X.Map
module MW = W.Map
Expand All @@ -586,26 +575,11 @@ end = struct
remove watches. *)
}

let range x t =
try MX.find x t.watches with Not_found -> W.Set.empty

let empty =
{ watches = MX.empty
; watching = MW.empty }

let add x w t =
let watches =
MX.update x (function
| None -> Some (SW.singleton w)
| Some ws -> Some (SW.add w ws)) t.watches
and watching =
MW.update w (function
| None -> Some (SX.singleton x)
| Some xs -> Some (SX.add x xs)) t.watching
in
{ watches ; watching }

let add_many x ws t =
let add_watches x ws t =
let watches =
MX.update x (function
| None -> Some ws
Expand All @@ -619,7 +593,7 @@ end = struct
in
{ watches ; watching }

let remove_range w t =
let remove_watch w t =
match MW.find w t.watching with
| xs ->
let watches =
Expand All @@ -638,7 +612,10 @@ end = struct
{ watches ; watching }
| exception Not_found -> t

let remove_dom x t =
let watches x t =
try MX.find x t.watches with Not_found -> W.Set.empty

let take_watches x t =
match MX.find x t.watches with
| ws ->
let watching =
Expand All @@ -653,40 +630,8 @@ end = struct
) watching
) ws t.watching
and watches = MX.remove x t.watches in
{ watches ; watching }
| exception Not_found -> t

let fold_range x f t acc =
match MX.find x t.watches with
| ws -> SW.fold f ws acc
| exception Not_found -> acc

let iter_range x f t =
match MX.find x t.watches with
| ws -> SW.iter f ws
| exception Not_found -> ()

let transfer_dom x x' t =
match MX.find x t.watches with
| ws ->
let watching =
SW.fold (fun w watching ->
MW.update w (function
| None ->
(* maps must be mutual inverses *)
assert false
| Some xs ->
Some (SX.add x' (SX.remove x xs))
) watching
) ws t.watching
and watches =
MX.update x' (function
| None -> Some ws
| Some ws' -> Some (SW.union ws ws')
) (MX.remove x t.watches)
in
{ watches ; watching }
| exception Not_found -> t
ws, { watches ; watching }
| exception Not_found -> W.Set.empty, t
end

(** Implementation of the [ComparableType] interface for semantic values. *)
Expand Down Expand Up @@ -833,8 +778,8 @@ struct
module DMA = DomainMap(A)(D)
module DMC = DomainMap(C)(D)

module AW = BinRel(A)(W)
module CW = BinRel(C)(W)
module AW = WatchMap(A)(W)
module CW = WatchMap(C)(W)

type t =
{ atoms : DMA.t
Expand Down Expand Up @@ -882,16 +827,20 @@ struct
match NF.normal_form r with
| Constant _ -> t
| Atom (a, _) ->
{ t with atom_watches = AW.add a w t.atom_watches }
{ t with
atom_watches =
AW.add_watches a (W.Set.singleton w) t.atom_watches }
| Composite (c, _) ->
{ t with composite_watches = CW.add c w t.composite_watches }
{ t with
composite_watches =
CW.add_watches c (W.Set.singleton w) t.composite_watches }

let unwatch w t =
{ atoms = t.atoms
; atom_watches = AW.remove_range w t.atom_watches
; atom_watches = AW.remove_watch w t.atom_watches
; variables = t.variables
; composites = t.composites
; composite_watches = CW.remove_range w t.composite_watches
; composite_watches = CW.remove_watch w t.composite_watches
; parents = t.parents
; triggers = t.triggers }

Expand Down Expand Up @@ -959,19 +908,21 @@ struct
| Constant _ -> invalid_arg "subst: cannot substitute a constant"
| Atom (a, o) ->
let variables = A.Set.remove a t.variables in
let ws, atom_watches = AW.take_watches a t.atom_watches in
D.add_offset (find_or_default_atom a t) o,
AW.range a t.atom_watches,
ws,
{ t with
atoms = DMA.remove a t.atoms ;
atom_watches = AW.remove_dom a t.atom_watches ;
atom_watches ;
variables }
| Composite (c, o) ->
let parents = untrack c t.parents in
let ws, composite_watches = CW.take_watches c t.composite_watches in
D.add_offset (find_or_default_composite c t) o,
CW.range c t.composite_watches,
ws,
{ t with
composites = DMC.remove c t.composites ;
composite_watches = CW.remove_dom c t.composite_watches ;
composite_watches ;
parents }
in
(* Add [ex] to justify that it applies to [nrr] *)
Expand All @@ -985,11 +936,11 @@ struct
match nrr_nf with
| Constant _ -> t
| Atom (a, _) ->
let atom_watches = AW.add_many a ws t.atom_watches in
let atom_watches = AW.add_watches a ws t.atom_watches in
let variables = A.Set.add a t.variables in
{ t with atom_watches ; variables }
| Composite (c, _) ->
let composite_watches = CW.add_many c ws t.composite_watches in
let composite_watches = CW.add_watches c ws t.composite_watches in
let parents = track c t.parents in
{ t with composite_watches ; parents }
in
Expand All @@ -1001,12 +952,14 @@ struct
shrunk it, it can only be empty. *)
assert false
| Atom (a, o) ->
let triggers = W.Set.union (AW.range a t.atom_watches) t.triggers in
let triggers =
W.Set.union (AW.watches a t.atom_watches) t.triggers
in
let atoms = DMA.add a (D.sub_offset nnrrd o) t.atoms in
{ t with atoms ; triggers }
| Composite (c, o) ->
let triggers =
W.Set.union (CW.range c t.composite_watches) t.triggers
W.Set.union (CW.watches c t.composite_watches) t.triggers
in
let composites = DMC.add c (D.sub_offset nnrrd o) t.composites in
{ t with composites ; triggers }
Expand Down Expand Up @@ -1117,10 +1070,10 @@ struct

let notify_atom a =
events.evt_atomic_change a;
AW.iter_range a events.evt_watch_trigger t.atom_watches
W.Set.iter events.evt_watch_trigger (AW.watches a t.atom_watches);
and notify_composite c =
events.evt_composite_change c;
CW.iter_range c events.evt_watch_trigger t.composite_watches
W.Set.iter events.evt_watch_trigger (CW.watches c t.composite_watches);
in

{ Ephemeral.atoms =
Expand Down

0 comments on commit 2462f71

Please sign in to comment.