Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion funsor/jax/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ def _min(x, y):

@ops.new_full.register(array)
def _new_full(x, shape, value):
return np.full(shape, value, dtype=np.result_type(x))
# TODO: revert this change
return np.full(shape, value).astype(dtype=np.result_type(x))


@ops.new_arange.register(array)
Expand Down
54 changes: 53 additions & 1 deletion funsor/joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
from funsor.interpretations import eager, moment_matching, normalize
from funsor.ops import AssociativeOp
from funsor.tensor import Tensor, align_tensor
from funsor.terms import Funsor, Independent, Number, Reduce, Unary
from funsor.terms import (
Funsor,
Independent,
Number,
Reduce,
Scatter,
Slice,
Unary,
Variable,
)
from funsor.typing import Variadic


Expand Down Expand Up @@ -192,3 +201,46 @@ def eager_independent_joint(joint, reals_var, bint_var, diag_var):
delta = Independent(joint.terms[0], reals_var, bint_var, diag_var)
new_terms = (delta,) + tuple(t.reduce(ops.add, bint_var) for t in joint.terms[1:])
return reduce(joint.bin_op, new_terms)


@eager.register(
Scatter,
AssociativeOp,
Tuple[Tuple[str, Union[Slice, Variable, Number]], ...],
Contraction[
AssociativeOp, AssociativeOp, frozenset, Tuple[Delta, Union[Tensor, Number]]
],
frozenset,
)
def eager_scatter_slice_contraction(op, subs, source, reduced_vars):
new_terms = []
for term in source.terms:
new_terms.append(Scatter(op, subs, term, frozenset()))
new_terms = tuple(new_terms)
return source.bin_op(*new_terms).reduce(
source.red_op, source.reduced_vars | reduced_vars
)


@eager.register(
Scatter,
AssociativeOp,
Tuple[Tuple[str, Union[Slice, Variable, Number]], ...],
Delta,
frozenset,
)
def eager_scatter_slice_delta(op, subs, source, reduced_vars):
new_terms = []
for name, (point, log_density) in source.terms:
new_terms.append(
(
name,
(
Scatter(op, subs, point, frozenset()), # This needs to fix
Scatter(op, subs, log_density, frozenset()),
),
)
)
new_terms = tuple(new_terms)
result = Delta(new_terms)
return result.reduce(op, reduced_vars)
5 changes: 5 additions & 0 deletions test/test_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,8 @@ def test_reduce_moment_matching_finite():
joint = delta + discrete + gaussian
with moment_matching:
joint.reduce(ops.logaddexp, reduced_vars)


def test_scatter_slice_delta():
# TODO: add test
pass