Skip to content

Conversation

@ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Mar 17, 2021

This is an implementation of autodiff. The goal is to address issues in computing expectations in TraceEnum_ELBO and TraceMarkovEnum_ELBO (#493). As of now it seems to fix nan gradients under eager interpretation in TraceEnum_ELBO.

The algorithm implements equivalents of linearize(), transpose() functions, and is tape-free (#446).

  1. Linearize. Variables that need to be linearized are replaced by primal- tangent tuple JVP(primal, tangent) and then pattern matched to propagate tangents, e.g.:
JVP(x, dx) + JVP(y, dy) = JVP(x+y, dx+dy)
JVP(x, dx) * JVP(y, dy) = JVP(x*y, ydx + xdy)
JVP(x, dx) * y = JVP(x*y, ydx)

Out tangent is a linear function of in tangents. JVP is used for (add,mul) semiring and LJVP is used for (logaddexp,add) semiring.

  1. Transpose of a linear function. Transpose is implemented simply by inverting the order of function execution and transposing matrices, in this case swapping more primitive operations .reduce(sum_op, "i") and .expand("i") (broadcasting does this automatically).

@ordabayevy ordabayevy added the WIP label Mar 17, 2021
inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
inputs.update(arg.inputs)
output = arg.output
fresh = frozenset()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be

fresh = frozenset(v.name for v in expanded_vars)

from funsor.terms import Binary, Funsor, Lambda, Number, Tuple, Variable, lazy
from funsor.testing import assert_close, random_tensor

funsor.set_backend("torch")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test files should read but not write the global backend. Instead you can decorate each test with

@pytest.mark.skipif(get_backend() != "torch", reason="backend-specific")

and then run tests with

FUNSOR_BACKEND=torch pytest test/test_autodiff.py

raise NotImplementedError(f"Missing pattern for {repr(expr)}")


class Expand(Funsor):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I'd like to better understand the need for this.

We've been trying to preserve the extensionality property in Funsor, which states that: if under every grounding substitution subs a pair of funsors f,g satisfy f(**subs) == g(**subs), then it should be permissible for an optimizer to replace funsor f with funsor g in any expression. IIUC this Expand funsor would break extensionality because f.expand(...) behaves as f under every grounding substitution.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants