Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 31 additions & 51 deletions funsor/cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from funsor.delta import Delta
from funsor.domains import find_domain
from funsor.gaussian import Gaussian
from funsor.interpretations import eager, normalize, reflect
from funsor.interpretations import eager, normalize, reflect, simplify
from funsor.interpreter import children, recursion_reinterpret
from funsor.ops import DISTRIBUTIVE_OPS, AssociativeOp, NullOp, null
from funsor.tensor import Tensor
Expand Down Expand Up @@ -255,12 +255,14 @@ def children_contraction(x):
return (x.red_op, x.bin_op, x.reduced_vars) + x.terms


@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor])
@simplify.register(
Contraction, AssociativeOp, AssociativeOp, frozenset, Variadic[Funsor]
)
def eager_contraction_generic_to_tuple(red_op, bin_op, reduced_vars, *terms):
return eager.interpret(Contraction, red_op, bin_op, reduced_vars, terms)
return Contraction(red_op, bin_op, reduced_vars, terms)


@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
@simplify.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
# Count the number of terms in which each variable is reduced.
counts = Counter()
Expand All @@ -276,9 +278,7 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
unique_vars = reduced_once & term.input_vars
if unique_vars:
result = term.reduce(red_op, unique_vars)
if result is not normalize.interpret(
Contraction, red_op, null, unique_vars, (term,)
):
if result is not normalize(term.reduce)(red_op, unique_vars):
terms[i] = result
reduced_vars -= unique_vars
leaf_reduced = True
Expand All @@ -294,9 +294,9 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
j = i + j_ + 1
unique_vars = reduced_twice.intersection(lhs.input_vars, rhs.input_vars)
result = Contraction(red_op, bin_op, unique_vars, lhs, rhs)
if result is not normalize.interpret(
Contraction, red_op, bin_op, unique_vars, (lhs, rhs)
): # did we make progress?
with normalize:
nr = Contraction(red_op, bin_op, unique_vars, lhs, rhs)
if result is not nr:
# pick the first evaluable pair
reduced_vars -= unique_vars
new_terms = terms[:i] + (result,) + terms[i + 1 : j] + terms[j + 1 :]
Expand All @@ -305,27 +305,28 @@ def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
return None


@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor)
@simplify.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor)
def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, term):
args = red_op, term, reduced_vars
return eager.dispatch(Reduce, *args)(*args)
return term.reduce(red_op, reduced_vars)


@eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor)
@simplify.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor)
def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs):
return bin_op(lhs, rhs).reduce(red_op, reduced_vars)

if not reduced_vars.issubset(lhs.input_vars & rhs.input_vars):
args = red_op, bin_op, reduced_vars, (lhs, rhs)
result = eager.dispatch(Contraction, *args)(*args)
if result is not None:
return result

args = bin_op, lhs, rhs
result = eager.dispatch(Binary, *args)(*args)
if result is not None and reduced_vars:
args = red_op, result, reduced_vars
result = eager.dispatch(Reduce, *args)(*args)
return result
@simplify.register(Binary, AssociativeOp, (Number, Funsor, Align), Number)
def eager_eliminate_unit(op, lhs, rhs):
if op in ops.UNITS and rhs.data == ops.UNITS[op]:
return lhs
return None


@simplify.register(Binary, AssociativeOp, Number, (Align, Funsor))
def eager_eliminate_unit(op, lhs, rhs):
if op in ops.UNITS and lhs.data == ops.UNITS[op]:
return rhs
return None


@eager.register(Contraction, ops.AddOp, ops.MulOp, frozenset, Tensor, Tensor)
Expand Down Expand Up @@ -458,7 +459,7 @@ def normalize_contraction_generic_args(red_op, bin_op, reduced_vars, *terms):
return normalize.interpret(Contraction, red_op, bin_op, reduced_vars, tuple(terms))


@normalize.register(Contraction, NullOp, NullOp, frozenset, Funsor)
@simplify.register(Contraction, NullOp, NullOp, frozenset, Funsor)
def normalize_trivial(red_op, bin_op, reduced_vars, term):
assert not reduced_vars
return term
Expand All @@ -480,18 +481,6 @@ def normalize_contraction_generic_tuple(red_op, bin_op, reduced_vars, terms):
new_terms = tuple(v.reduce(red_op, reduced_vars) for v in terms)
return Contraction(red_op, bin_op, frozenset(), *new_terms)

if bin_op in ops.UNITS and any(
isinstance(t, Number) and t.data == ops.UNITS[bin_op] for t in terms
):
new_terms = tuple(
t
for t in terms
if not (isinstance(t, Number) and t.data == ops.UNITS[bin_op])
)
if not new_terms: # everything was a unit
new_terms = (terms[0],)
return Contraction(red_op, bin_op, reduced_vars, *new_terms)

for i, v in enumerate(terms):

if not isinstance(v, Contraction):
Expand Down Expand Up @@ -541,15 +530,6 @@ def unary_neg_variable(op, arg):
#######################################################################


@normalize.register(Subs, Funsor, tuple)
def do_fresh_subs(arg, subs):
if not subs:
return arg
if all(name in arg.fresh for name, sub in subs):
return arg.eager_subs(subs)
return None


@normalize.register(Subs, Contraction, tuple)
def distribute_subs_contraction(arg, subs):
new_terms = tuple(
Expand Down Expand Up @@ -581,10 +561,10 @@ def binary_divide(op, lhs, rhs):
return lhs * Unary(ops.reciprocal, rhs)


@normalize.register(Unary, ops.ExpOp, Unary[ops.LogOp, Funsor])
@normalize.register(Unary, ops.LogOp, Unary[ops.ExpOp, Funsor])
@normalize.register(Unary, ops.NegOp, Unary[ops.NegOp, Funsor])
@normalize.register(Unary, ops.ReciprocalOp, Unary[ops.ReciprocalOp, Funsor])
@simplify.register(Unary, ops.ExpOp, Unary[ops.LogOp, Funsor])
@simplify.register(Unary, ops.LogOp, Unary[ops.ExpOp, Funsor])
@simplify.register(Unary, ops.NegOp, Unary[ops.NegOp, Funsor])
@simplify.register(Unary, ops.ReciprocalOp, Unary[ops.ReciprocalOp, Funsor])
def unary_log_exp(op, arg):
return arg.arg

Expand Down
5 changes: 3 additions & 2 deletions funsor/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import makefun

from funsor.instrument import debug_logged
from funsor.terms import Funsor, FunsorMeta, Variable, eager, to_funsor
from funsor.interpretations import simplify
from funsor.terms import Funsor, FunsorMeta, Variable, to_funsor


def _erase_types(fn):
Expand Down Expand Up @@ -257,7 +258,7 @@ def _alpha_convert(self, alpha_subs):
pattern = (Result,) + tuple(
_hint_to_pattern(input_types[k]) for k in Result._ast_fields
)
eager.register(*pattern)(_erase_types(fn))
simplify.register(*pattern)(_erase_types(fn))
return Result


Expand Down
36 changes: 17 additions & 19 deletions funsor/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from funsor.cnf import Contraction, GaussianMixture
from funsor.delta import Delta
from funsor.gaussian import Gaussian, _mv, _trace_mm, _vv, align_gaussian
from funsor.interpretations import eager, normalize
from funsor.interpretations import eager, normalize, simplify
from funsor.tensor import Tensor
from funsor.terms import (
Funsor,
Expand Down Expand Up @@ -88,12 +88,22 @@ def normalize_integrate(log_measure, integrand, reduced_vars):


@normalize.register(
Unary, ops.ExpOp, Contraction[ops.NullOp, ops.AddOp, frozenset, tuple]
)
def normalize_expadd(op, arg):
# TODO guarantee this is lazy - want no direct exponentiation of ground values
return Contraction(arg.red_op, ops.mul, arg.reduced_vars, *map(op, arg.terms))


@simplify.register(
Integrate,
Contraction[Union[ops.NullOp, ops.LogaddexpOp], ops.AddOp, frozenset, tuple],
Funsor,
frozenset,
)
def normalize_integrate_contraction(log_measure, integrand, reduced_vars):
def simplify_integrate_contraction(log_measure, integrand, reduced_vars):
# TODO ensure this rule fires - delta points must be substituted correctly
# TODO should this be part of normalize or simplify? probably normalize...
reduced_names = frozenset(v.name for v in reduced_vars)
delta_terms = [
t
Expand All @@ -109,34 +119,22 @@ def normalize_integrate_contraction(log_measure, integrand, reduced_vars):
if name in reduced_names.intersection(integrand.inputs)
}
)
return normalize_integrate(log_measure, integrand, reduced_vars)
return Integrate(log_measure, integrand, reduced_vars)


@eager.register(
@simplify.register(
Contraction,
ops.AddOp,
ops.MulOp,
frozenset,
Unary[ops.ExpOp, Union[GaussianMixture, Delta, Gaussian, Number, Tensor]],
(Variable, Delta, Gaussian, Number, Tensor, GaussianMixture),
Funsor, # TODO is this too broad?
)
def eager_contraction_binary_to_integrate(red_op, bin_op, reduced_vars, lhs, rhs):
reduced_names = frozenset(v.name for v in reduced_vars)
if not (reduced_names.issubset(lhs.inputs) and reduced_names.issubset(rhs.inputs)):
args = red_op, bin_op, reduced_vars, (lhs, rhs)
result = eager.dispatch(Contraction, *args)(*args)
if result is not None:
return result

args = lhs.log(), rhs, reduced_vars
result = eager.dispatch(Integrate, *args)(*args)
if result is not None:
return result

return None
return Integrate(lhs.log(), rhs, reduced_vars)


@eager.register(Integrate, GaussianMixture, Funsor, frozenset)
@simplify.register(Integrate, GaussianMixture, Funsor, frozenset)
def eager_integrate_gaussianmixture(log_measure, integrand, reduced_vars):
real_vars = frozenset(v for v in reduced_vars if v.dtype == "real")
if reduced_vars <= real_vars:
Expand Down
85 changes: 72 additions & 13 deletions funsor/interpretations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@
from timeit import default_timer

from . import instrument
from .interpreter import get_interpretation, pop_interpretation, push_interpretation
from .interpreter import (
get_interpretation,
is_atom,
pop_interpretation,
push_interpretation,
reinterpret,
)
from .registry import KeyedRegistry
from .util import get_backend

Expand Down Expand Up @@ -259,6 +265,45 @@ def register(cls, *args):
return cls.registry.register(*args)


class NormalizedInterpretation(Interpretation):
def __init__(self, subinterpretation):
super().__init__(f"Normalized({subinterpretation.__name__})")
self.subinterpretation = subinterpretation
self.register = self.subinterpretation.register
self.dispatch = self.subinterpretation.dispatch
self._cache = {} # weakref.WeakValueDictionary() # TODO make this work

def interpret(self, cls, *args):
# 1. try self.subinterpret.
result = self.subinterpretation.interpret(cls, *args)
if result is not None:
return result

# 2. normalize to a Contraction normal form (will succeed)
# Note eager_contraction_generic_recursive() effectively fuses this
# step with step 3 below to short-circuit some logic.
with normalize:
normalized_args = []
for arg in args:
try:
normalized_args.append(arg if is_atom(arg) else self._cache[arg])
except KeyError:
normalized_arg = reinterpret(arg)
self._cache[arg] = normalized_arg
normalized_args.append(normalized_arg)
normal_form = cls(*normalized_args)

# 3. try evaluating that normal form
with PrioritizedInterpretation(self.subinterpretation, simplify):
# TODO use .interpret instead of reinterpret here to avoid traversal
result = reinterpret(normal_form)
if result is not normal_form: # I.e. was progress made?
return result

# 4. if that fails, fall back to base interpretation of cls(*args)
return None


class Memoize(Interpretation):
"""
Exploits cons-hashing to do implicit common subexpression elimination.
Expand Down Expand Up @@ -303,6 +348,18 @@ def memoize(cache=None):
# Concrete interpretations.


class Simplify(DispatchedInterpretation):

is_total = True # because it always ends with normalize

def interpret(self, cls, *args):
result = super().interpret(cls, *args)
if result is None:
with normalize:
result = cls(*args)
return result


@CallableInterpretation
def reflect(cls, *args):
raise ValueError("Should be overwritten in terms.py")
Expand All @@ -317,35 +374,34 @@ def reflect(cls, *args):
numerical operations.
"""

lazy_base = DispatchedInterpretation("lazy")
simplify = Simplify("simplify")

lazy_base = NormalizedInterpretation(DispatchedInterpretation("lazy"))
lazy = PrioritizedInterpretation(lazy_base, reflect)
"""
Performs substitutions eagerly, but construct lazy funsors for everything else.
"""

eager_base = DispatchedInterpretation("eager")
eager = PrioritizedInterpretation(eager_base, normalize_base, reflect)
eager_base = NormalizedInterpretation(DispatchedInterpretation("eager"))
eager = PrioritizedInterpretation(eager_base, reflect)
"""
Eager exact naive interpretation wherever possible.
"""

die = DispatchedInterpretation("die")
eager_or_die = PrioritizedInterpretation(eager_base, die, reflect)
eager_or_die = PrioritizedInterpretation(eager_base.subinterpretation, die, reflect)

sequential_base = DispatchedInterpretation("sequential")
# XXX does this work with sphinx/help()?
sequential = PrioritizedInterpretation(
sequential_base, eager_base, normalize_base, reflect
)
sequential_base = NormalizedInterpretation(DispatchedInterpretation("sequential"))
sequential = PrioritizedInterpretation(sequential_base, eager)
"""
Eagerly execute ops with known implementations; additonally execute
vectorized ops sequentially if no known vectorized implementation exists.
"""

moment_matching_base = DispatchedInterpretation("moment_matching")
moment_matching = PrioritizedInterpretation(
moment_matching_base, eager_base, normalize_base, reflect
moment_matching_base = NormalizedInterpretation(
DispatchedInterpretation("moment_matching")
)
moment_matching = PrioritizedInterpretation(moment_matching_base, eager)
"""
A moment matching interpretation of :class:`Reduce` expressions. This falls
back to :class:`eager` in other cases.
Expand All @@ -359,6 +415,8 @@ def reflect(cls, *args):
"CallableInterpretation",
"DispatchedInterpretation",
"Interpretation",
"NormalizedInterpretation",
"PrioritizedInterpretation",
"Memoize",
"StatefulInterpretation",
"die",
Expand All @@ -369,4 +427,5 @@ def reflect(cls, *args):
"normalize",
"reflect",
"sequential",
"simplify",
]
Loading