Standalone reverse-mode autodiff library.
- Subclass `retrodiff.Function` to create base functions (or use the presets from `retrodiff.utils.function`).
- Build the DAG function by composing `Node`s and `Function`s.
- Set values on the input nodes.
- Call `node.forward()` and `node.backward()` on the output node to calculate output values and gradients.
- Values and gradients are stored in `node.value` and `node.grad`. You can also view the full function with `node.show_tree()`.
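To make the steps above concrete, here is a minimal, self-contained sketch of how a `Node`/`Function` DAG of this kind can compute values and propagate gradients. It is a hypothetical re-implementation for illustration only, not retrodiff's code: the attribute names `fn` and `parents`, the helper `_topo`, and zeroing every gradient at the start of `backward` are assumptions. It borrows only the interface this README describes: `forward(*inputs)`, `backward(grad, wrt, *inputs)`, `node.value`, and `node.grad`.

```python
class Function:
    """Hypothetical base class: subclasses implement forward and backward."""

    def forward(self, *inputs):
        raise NotImplementedError

    def backward(self, grad, wrt, *inputs):
        raise NotImplementedError

    def __call__(self, *parents):
        # Applying a Function to Nodes creates a new Node in the DAG.
        return Node(self, parents)


class Node:
    def __init__(self, fn=None, parents=()):
        self.fn = fn            # None for input nodes (assumed design)
        self.parents = parents  # nodes whose values feed fn
        self.value = None
        self.grad = 0.0

    def _topo(self):
        # Depth-first post-order: every node appears after all of its inputs.
        order, seen = [], set()

        def visit(node):
            if id(node) in seen:
                return
            seen.add(id(node))
            for p in node.parents:
                visit(p)
            order.append(node)

        visit(self)
        return order

    def forward(self):
        # Evaluate each non-input node once, inputs first.
        for node in self._topo():
            if node.fn is not None:
                node.value = node.fn.forward(*(p.value for p in node.parents))
        return self.value

    def backward(self, grad=1.0):
        order = self._topo()
        for node in order:
            node.grad = 0.0
        self.grad = grad  # seed: d(output)/d(output)
        # Reverse topological order: a node's grad is complete before it is pushed to its inputs.
        for node in reversed(order):
            if node.fn is None:
                continue
            inputs = [p.value for p in node.parents]
            for wrt, p in enumerate(node.parents):
                p.grad += node.fn.backward(node.grad, wrt, *inputs)


class Mul(Function):
    def forward(self, x, y): return x * y
    def backward(self, grad, wrt, x, y): return (y, x)[wrt] * grad


class Add(Function):
    def forward(self, x, y): return x + y
    def backward(self, grad, wrt, x, y): return grad


mul, add = Mul(), Add()
x, y, z = Node(), Node(), Node()
f = add(mul(x, y), z)  # f = x * y + z
x.value, y.value, z.value = 2, 3, 1
f.forward()
f.backward(1)
print(f.value, x.grad, y.grad, z.grad)  # 7 3.0 2.0 1.0
```

Because gradients are added with `+=`, a node consumed more than once (for example `x` in `mul(x, x)`) receives the sum of the contributions from every path, which is what the chain rule requires.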
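Example:

```python
from retrodiff import Function, Node  # the import path of Node is assumed

class Mul(Function):
    def forward(self, x, y): return x * y
    # wrt is the index of the input: d(xy)/dx = y, d(xy)/dy = x
    def backward(self, grad, wrt, x, y): return (y, x)[wrt] * grad

class Add(Function):
    def forward(self, x, y): return x + y
    # d(x + y)/dx = d(x + y)/dy = 1
    def backward(self, grad, wrt, x, y): return grad

mul, add = Mul(), Add()
x, y, z = Node(), Node(), Node()
f = add(mul(x, y), z)  # f = x * y + z

x.value = 2
y.value = 3
z.value = 1

f.forward()
f.backward(1)  # seed the output gradient df/df = 1

out = f.value  # 2 * 3 + 1 = 7
grads = x.grad, y.grad, z.grad, f.grad  # df/dx = y = 3, df/dy = x = 2, df/dz = 1
```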
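Since every `Function` subclass supplies its own `backward`, it can help to check it numerically before composing it into a DAG. The sketch below compares `backward` against central finite differences of `forward`. It is a hypothetical helper, not part of retrodiff: `check_backward`, `BadMul`, `eps`, and `tol` are illustrative names, and the classes are written without the `retrodiff.Function` base so the snippet runs on its own. It assumes only the `forward(*inputs)` / `backward(grad, wrt, *inputs)` signatures used in the example.

```python
class Mul:
    def forward(self, x, y): return x * y
    def backward(self, grad, wrt, x, y): return (y, x)[wrt] * grad


class Add:
    def forward(self, x, y): return x + y
    def backward(self, grad, wrt, x, y): return grad


class BadMul(Mul):
    # Deliberately wrong: swaps the two partial derivatives.
    def backward(self, grad, wrt, x, y): return (x, y)[wrt] * grad


def check_backward(fn, *inputs, eps=1e-6, tol=1e-4):
    """Compare fn.backward with central differences of fn.forward, one input at a time."""
    for wrt in range(len(inputs)):
        plus, minus = list(inputs), list(inputs)
        plus[wrt] += eps
        minus[wrt] -= eps
        numeric = (fn.forward(*plus) - fn.forward(*minus)) / (2 * eps)
        analytic = fn.backward(1.0, wrt, *inputs)  # upstream grad of 1
        if abs(numeric - analytic) > tol * max(1.0, abs(numeric)):
            return False
    return True


print(check_backward(Mul(), 2.0, 3.0))     # True
print(check_backward(Add(), 2.0, 3.0))     # True
print(check_backward(BadMul(), 2.0, 3.0))  # False
```

Use inputs where the partial derivatives differ (here `x != y`); with `x == y`, a swapped-derivative bug like `BadMul` would go unnoticed.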
See also the examples and `retrodiff.utils.nn`.