282 changes: 282 additions & 0 deletions funsor/autodiff.py
@@ -0,0 +1,282 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import math
from collections import defaultdict
from functools import reduce, singledispatch

import funsor.ops as ops
from funsor import Tensor
from funsor.adjoint import _alpha_unmangle
from funsor.cnf import Contraction
from funsor.domains import Array, Bint, Real, Reals
from funsor.interpretations import autodiff, trace
from funsor.interpreter import interpretation
from funsor.ops import AssociativeOp, LogOp
from funsor.terms import (
Binary,
Funsor,
Lambda,
Number,
Reduce,
Tuple,
Unary,
Variable,
eager,
lazy,
)


class JVP(Tuple):
"""
    Tuple: (Primal, Tangent)
Semiring: (Add, Mul)
"""

sum_op = ops.add
prod_op = ops.mul
div_op = ops.safediv
zero = Number(0)
one = Number(1)

@property
def primal(self):
return self[0]

@property
def tangent(self):
return self[1]


class LJVP(Tuple):
"""
    Tuple: (LogPrimal, LogTangent)
Semiring: (Logaddexp, Add)
"""

sum_op = ops.logaddexp
prod_op = ops.add
div_op = ops.safesub
zero = Number(-math.inf)
one = Number(0)

@property
def primal(self):
return self[0]

@property
def tangent(self):
return self[1]
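
# Illustrative note: a ``JVP`` pair behaves like a dual number over the
# (add, mul) semiring, so for z = x * y the tangent is y * dx + x * dy.
# ``LJVP`` is the same construction in log space, where (logaddexp, add) plays
# the role of (add, mul) and the product rule becomes logaddexp(y + dx, x + dy).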


@trace.register(Binary, AssociativeOp, Funsor, Funsor)
def trace_binary_associativeop(op, lhs, rhs):
with lazy:
result = Binary(op, lhs, rhs)
return result


@trace.register(Reduce, AssociativeOp, Funsor, frozenset)
def trace_reduce_associativeop(op, arg, reduced_args):
with lazy:
result = Reduce(op, arg, reduced_args)
return result


def to_jvp(primal):
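    # Pair ``primal`` with a fresh tangent placeholder: a single free Variable
    # whose batched output is indexed by the primal's inputs, giving a tangent
    # with the same inputs and output as the primal.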
input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items())
output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output
tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)]
return JVP(primal, tangent_placeholder)


def to_ljvp(primal):
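    # Same construction as ``to_jvp``, but the resulting pair is interpreted
    # under the (logaddexp, add) semiring.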
input_vars = tuple(Variable(key, value) for key, value in primal.inputs.items())
output = reduce(lambda x, y: Lambda(y, x), reversed(input_vars), primal).output
tangent_placeholder = Variable(str(id(primal)), output)[tuple(primal.inputs)]
return LJVP(primal, tangent_placeholder)


def grad(expr, targets, out_tangent=None):
out_tangent = expr.one if out_tangent is None else out_tangent
in_tangents = set(target.tangent for target in targets)
transposes = transpose(
expr.tangent,
out_tangent,
in_tangents,
defaultdict(lambda: expr.zero),
type(expr),
)
result = {}
for target in targets:
result[target] = transposes[target.tangent]
return result
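
# Rough usage sketch (untested and hypothetical; assumes the torch backend and
# the default out_tangent seeding of ``expr.one``):
#
#   from collections import OrderedDict
#   import torch
#
#   x = to_jvp(Tensor(torch.tensor([1.0, 2.0]), OrderedDict(i=Bint[2])))
#   with autodiff:
#       z = x * x                  # forward pass traced together with tangents
#   dz = grad(z, (x,))[x]          # adjoint w.r.t. the input tangent placeholder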


@singledispatch
def transpose(expr, out_tangent, in_tangents, result, semiring):
if expr in in_tangents:
result[expr] = semiring.sum_op(result[expr], out_tangent)
return result


@transpose.register(Binary)
def transpose_binary(expr, out_tangent, in_tangents, result, semiring):
    op, lhs, rhs = expr.op, expr.lhs, expr.rhs
sum_op, prod_op = semiring.sum_op, semiring.prod_op

if expr in in_tangents:
result[expr] = sum_op(result[expr], out_tangent)
out_tangent = result[expr]
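
    # Transpose rules: for z = sum_op(x, y) the adjoint passes through
    # unchanged (reduced over variables absent from each operand); for
    # z = prod_op(x, y) the adjoint of x is prod_op(y, z_adj) reduced over
    # variables not in x, and symmetrically for y.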

if op is sum_op:
lhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - lhs.input_vars)
rhs_adj = out_tangent.reduce(sum_op, out_tangent.input_vars - rhs.input_vars)
elif op is prod_op:
lhs_adj = prod_op(rhs, out_tangent).reduce(
sum_op, out_tangent.input_vars - lhs.input_vars
)
rhs_adj = prod_op(lhs, out_tangent).reduce(
sum_op, out_tangent.input_vars - rhs.input_vars
)
else:
        return result  # TODO: is this always correct for ops outside the semiring?
result = transpose(lhs, lhs_adj, in_tangents, result, semiring)
result = transpose(rhs, rhs_adj, in_tangents, result, semiring)
return result


@transpose.register(Reduce)
def transpose_reduce(expr, out_tangent, in_tangents, result, semiring):
# fix this in contraction as well
op, arg, reduced_vars = _alpha_unmangle(expr)
sum_op, prod_op = semiring.sum_op, semiring.prod_op

if expr in in_tangents:
result[expr] = sum_op(result[expr], out_tangent)
out_tangent = result[expr]

if op is sum_op:
arg_adj = out_tangent.expand(tuple(reduced_vars))
result = transpose(arg, arg_adj, in_tangents, result, semiring)
return result
elif op is prod_op:
# this is unnecessary
return result
else:
raise ValueError


@transpose.register(Contraction)
def transpose_contraction(expr, out_tangent, in_tangents, result, semiring):
    # NOTE: semiring is accepted to match the generic transpose signature, but
    # this handler currently assumes the (ops.add, ops.mul) semiring below.
if expr in in_tangents:
result[expr] += out_tangent
out_tangent = result[expr]

if expr.red_op is ops.nullop:
for term in expr.terms:
if expr.bin_op is ops.add:
term_adj = out_tangent.reduce(
ops.add, out_tangent.input_vars - term.input_vars
)
elif expr.bin_op is ops.mul:
expr_div_term = reduce(
ops.mul, tuple(t for t in expr.terms if t is not term)
)
term_adj = (out_tangent * expr_div_term).reduce(
ops.add, out_tangent.input_vars - term.input_vars
)
else:
raise ValueError
            result = transpose(term, term_adj, in_tangents, result, semiring)
elif expr.bin_op is ops.nullop:
for term in expr.terms: # only one term
if expr.red_op is ops.add:
term_adj = out_tangent.expand(tuple(expr.reduced_vars))
elif expr.red_op is ops.mul:
term_adj = ops.safediv(ops.mul(out_tangent, expr), term)
else:
raise ValueError
            result = transpose(term, term_adj, in_tangents, result, semiring)
else:
raise ValueError
return result


@eager.register(Binary, AssociativeOp, JVP, JVP)
@eager.register(Binary, AssociativeOp, LJVP, LJVP)
@autodiff.register(Binary, AssociativeOp, JVP, JVP)
@autodiff.register(Binary, AssociativeOp, LJVP, LJVP)
def jvp_binary(op, lhs, rhs):
sum_op, prod_op = lhs.sum_op, lhs.prod_op
primal = Binary(op, lhs.primal, rhs.primal)
if op is sum_op:
tangent = sum_op(lhs.tangent, rhs.tangent)
elif op is prod_op:
tangent = sum_op(
prod_op(rhs.primal, lhs.tangent), prod_op(lhs.primal, rhs.tangent)
)
else:
raise NotImplementedError
return type(lhs)(primal, tangent)


@eager.register(Binary, AssociativeOp, JVP, (Number, Tensor))
@eager.register(Binary, AssociativeOp, LJVP, (Number, Tensor))
@autodiff.register(Binary, AssociativeOp, JVP, (Number, Tensor))
@autodiff.register(Binary, AssociativeOp, LJVP, (Number, Tensor))
def jvp_binary_jvp_funsor(op, lhs, rhs):
sum_op, prod_op = lhs.sum_op, lhs.prod_op
primal = Binary(op, lhs.primal, rhs)
if op is sum_op:
tangent = sum_op(lhs.tangent, rhs)
elif op is prod_op:
tangent = prod_op(lhs.tangent, rhs)
else:
raise NotImplementedError
return type(lhs)(primal, tangent)


@eager.register(Binary, AssociativeOp, (Number, Tensor), JVP)
@eager.register(Binary, AssociativeOp, (Number, Tensor), LJVP)
@autodiff.register(Binary, AssociativeOp, (Number, Tensor), JVP)
@autodiff.register(Binary, AssociativeOp, (Number, Tensor), LJVP)
def jvp_binary_funsor_jvp(op, lhs, rhs):
sum_op, prod_op = rhs.sum_op, rhs.prod_op
primal = Binary(op, lhs, rhs.primal)
if op is sum_op:
tangent = sum_op(lhs, rhs.tangent)
elif op is prod_op:
tangent = prod_op(lhs, rhs.tangent)
else:
raise NotImplementedError
return type(rhs)(primal, tangent)


@eager.register(Reduce, AssociativeOp, JVP, frozenset)
@eager.register(Reduce, AssociativeOp, LJVP, frozenset)
@autodiff.register(Reduce, AssociativeOp, JVP, frozenset)
@autodiff.register(Reduce, AssociativeOp, LJVP, frozenset)
def jvp_reduce(op, arg, reduced_vars):
sum_op, prod_op, div_op = arg.sum_op, arg.prod_op, arg.div_op
primal = Reduce(op, arg.primal, reduced_vars)
if op is sum_op:
tangent = Reduce(sum_op, arg.tangent, reduced_vars)
elif op is prod_op:
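        # Product rule for reductions:
        #   d(prod_i x_i) = sum_i (prod_j x_j / x_i) * dx_i,
        # realized as reduce(sum_op, div_op(prod_op(tangent, primal), primal)),
        # with safediv / safesub guarding against zeros (-inf in log space).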
tangent = Reduce(
sum_op, div_op(prod_op(arg.tangent, primal), arg.primal), reduced_vars
)
else:
raise NotImplementedError
return type(arg)(primal, tangent)


# @lazy.register(Unary, LogOp, JVP)
# @eager.register(Unary, LogOp, JVP)
# def jvp_log(op, arg):
# arg_primal, arg_tangent = arg
# primal = Unary(op, arg_primal)
# tangent = Binary(ops.truediv, arg_tangent, arg_primal)
# return JVP(primal, tangent)
6 changes: 5 additions & 1 deletion funsor/domains.py
@@ -261,6 +261,7 @@ def _find_domain_getitem(op, lhs_domain, rhs_domain):
return Array[dtype, shape]
elif isinstance(lhs_domain, ProductDomain):
# XXX should this return a Union?
return Real
raise NotImplementedError(
"Cannot statically infer domain from: " f"{lhs_domain}[{rhs_domain}]"
)
@@ -325,7 +326,10 @@ def _find_domain_associative_generic(op, *domains):
return Array[domains[0].dtype, ()]

lhs, rhs = domains
if lhs.dtype == "real" or rhs.dtype == "real":
# FIXME
if lhs is rhs:
return lhs
elif lhs.dtype == "real" or rhs.dtype == "real":
dtype = "real"
elif op in (ops.add, ops.mul, ops.pow, ops.max, ops.min):
dtype = op(lhs.dtype - 1, rhs.dtype - 1) + 1
14 changes: 14 additions & 0 deletions funsor/interpretations.py
@@ -329,6 +329,20 @@ def reflect(cls, *args):
Eager exact naive interpretation wherever possible.
"""

trace_base = DispatchedInterpretation("trace")
trace = PrioritizedInterpretation(trace_base, eager_base, normalize_base, reflect)
"""
Constructs a trace (expression) in terms of primitive operations.
"""

autodiff_base = DispatchedInterpretation("autodiff")
autodiff = PrioritizedInterpretation(
autodiff_base, trace_base, eager_base, normalize_base, reflect
)
"""
Constructs a trace (expression) in terms of primitive operations.
"""

die = DispatchedInterpretation("die")
eager_or_die = PrioritizedInterpretation(eager_base, die, reflect)

25 changes: 24 additions & 1 deletion funsor/tensor.py
@@ -18,9 +18,10 @@
from . import ops
from .delta import Delta
from .domains import Array, ArrayType, Bint, Product, Real, Reals, find_domain
from .ops import GetitemOp, MatmulOp, Op, ReshapeOp
from .ops import AssociativeOp, GetitemOp, MatmulOp, Op, ReshapeOp
from .terms import (
Binary,
Expand,
Finitary,
Funsor,
FunsorMeta,
@@ -682,6 +683,28 @@ def eager_scatter_tensor(op, subs, source, reduced_vars):
return Tensor(data, destin_inputs, output.dtype)


@eager.register(Expand, Number, tuple)
def eager_number_expand(arg, expanded_vars):
shape = tuple(var.output.size for var in expanded_vars)
inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
data = ops.new_full(
funsor.tensor.get_default_prototype(),
shape,
arg.data
)
return Tensor(data, inputs, arg.dtype)


@eager.register(Expand, Tensor, tuple)
def eager_tensor_expand(arg, expanded_vars):
expanded_shape = tuple(var.output.size for var in expanded_vars)
old_shape = (-1,) * (len(arg.inputs) + len(arg.output.shape))
new_shape = expanded_shape + old_shape
inputs = OrderedDict([(var.name, var.output) for var in expanded_vars])
inputs.update(arg.inputs)
return Tensor(ops.expand(arg.data, new_shape), inputs, arg.dtype)
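
# Sketch of the intended semantics (hedged; the ``Expand`` term itself is
# defined in funsor.terms and not shown in this diff): expanding a Tensor over
# new variables prepends one dimension per variable, e.g. a scalar Tensor
# expanded over (Variable("i", Bint[3]),) gains the input {"i": Bint[3]} and
# its data is broadcast to shape (3,). The -1 entries in ``new_shape`` keep the
# existing dimensions unchanged (torch-style expand semantics).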


@eager.register(Binary, Op, Tensor, Number)
def eager_binary_tensor_number(op, lhs, rhs):
dtype = find_domain(op, lhs.output, rhs.output).dtype