This project has now been incorporated into GPJax.
"Module pytrees" that cleanly handle parameter trainability and transformations for JAX models.
```bash
pip install mytree
```
- Mark leaf attributes with `param_field` to set a default bijector and trainable status.
- Unmarked leaf attributes default to an `Identity` bijector and trainability set to `True`.
```python
from mytree import Mytree, param_field, Softplus, Identity

class SimpleModel(Mytree):
    weight: float = param_field(bijector=Softplus, trainable=False)

    def __init__(self, weight, bias):
        self.weight = weight
        self.bias = bias  # Unmarked leaf attribute `bias` has an `Identity` bijector and trainability set to `True`.

    def __call__(self, test_point):
        return test_point * self.weight + self.bias
```
- Fully compatible with Distrax and TensorFlow Probability bijectors, so feel free to use these (see the sketch after the dataclass example below)!
- Works seamlessly with the `dataclasses.dataclass` decorator!
```python
from dataclasses import dataclass

@dataclass
class SimpleModel(Mytree):
    weight: float = param_field(bijector=Softplus, trainable=False)
    bias: float

    def __call__(self, test_point):
        return test_point * self.weight + self.bias
```
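As noted above, bijectors from TensorFlow Probability (or Distrax) can be used in place of mytree's built-ins. A minimal sketch, assuming TFP's JAX substrate is installed and that `param_field` accepts its bijector instances directly; the class `BoundedModel` and its `rate` field are illustrative only.

```python
import tensorflow_probability.substrates.jax.bijectors as tfb
from dataclasses import dataclass
from mytree import Mytree, param_field

@dataclass
class BoundedModel(Mytree):
    # `constrain` would apply tfb.Sigmoid().forward to `rate`; `unconstrain` applies its inverse.
    rate: float = param_field(bijector=tfb.Sigmoid())
```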
Update values via the `replace` method.
```python
model = SimpleModel(1.0, 2.0)
model.replace(weight=123.0)
```
```
SimpleModel(weight=123.0, bias=2.0)
```
Use `constrain` / `unconstrain` to return a `Mytree` with each parameter's bijector `forward` / `inverse` operation applied!
```python
model.constrain()
model.unconstrain()
```
```
SimpleModel(weight=1.3132616, bias=2.0)
SimpleModel(weight=0.5413248, bias=2.0)
```
Default transformations can be replaced on an instance via the `replace_bijector` method.
```python
new = model.replace_bijector(weight=Identity)
new.constrain()
new.unconstrain()
```
```
SimpleModel(weight=1.0, bias=2.0)
SimpleModel(weight=1.0, bias=2.0)
```
And we see that `weight`'s parameter is no longer transformed, now that it uses the `Identity` bijector.
Applying `stop_gradient` within the loss function prevents the flow of gradients during forward- or reverse-mode automatic differentiation.
```python
import jax

# Create simulated data.
n = 100
key = jax.random.PRNGKey(123)
x = jax.random.uniform(key, (n,))
y = 3.0 * x + 2.0 + 1e-3 * jax.random.normal(key, (n,))

# Define a mean-squared-error loss.
def loss(model: SimpleModel) -> float:
    model = model.stop_gradient()  # Stop gradients!
    return jax.numpy.sum((y - model(x)) ** 2)

jax.grad(loss)(model)
```
```
SimpleModel(weight=0.0, bias=-188.37418)
```
As `weight`'s trainability was set to `False`, its gradient is zero, as expected!
Default trainability status can be replaced via the `replace_trainable` method.
```python
new = model.replace_trainable(weight=True)
jax.grad(loss)(new)
```
```
SimpleModel(weight=-121.42676, bias=-188.37418)
```
And we see that `weight`'s gradient is no longer zero.
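Putting these pieces together, one natural pattern is to take gradients in the unconstrained space and map the updated parameters back with `constrain`. A minimal gradient-descent sketch, not from the original examples; the helper `step` and the learning rate are illustrative only.

```python
import jax
import jax.tree_util as jtu

def step(model: SimpleModel, learning_rate: float = 0.01) -> SimpleModel:
    unconstrained = model.unconstrain()

    # Evaluate the loss on the constrained model, so predictions use the original parameterisation.
    def unconstrained_loss(m: SimpleModel) -> float:
        return loss(m.constrain())

    grads = jax.grad(unconstrained_loss)(unconstrained)

    # Mytree instances are pytrees, so a plain tree_map performs the update.
    updated = jtu.tree_map(lambda p, g: p - learning_rate * g, unconstrained, grads)
    return updated.constrain()

fitted = step(model)  # Returns a new, updated SimpleModel; `model` itself is unchanged.
```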
View the field metadata pytree via `meta`.
```python
from mytree import meta
meta(model)
```
```
SimpleModel(weight=({'bijector': Bijector(forward=<function <lambda> at 0x17a024e50>, inverse=<function <lambda> at 0x17a024430>), 'trainable': False, 'pytree_node': True}, 1.0), bias=({}, 2.0))
```
Or view the metadata pytree leaves via `meta_leaves`.
```python
from mytree import meta_leaves
meta_leaves(model)
```
```
[({}, 2.0),
 ({'bijector': Bijector(forward=<function <lambda> at 0x17a024e50>, inverse=<function <lambda> at 0x17a024430>),
   'trainable': False,
   'pytree_node': True}, 1.0)]
```
Note that this shows any metadata defined via a `dataclasses.field` for the pytree leaves, so feel free to define your own.
Transformations that depend on leaf metadata can be applied via the `meta_map` function.
```python
from mytree import meta_map

# The function passed to `meta_map` takes a `(meta, leaf)` tuple as its argument!
def if_trainable_then_10(meta_leaf):
    meta, leaf = meta_leaf

    if meta.get("trainable", True):
        return 10.0
    else:
        return leaf

meta_map(if_trainable_then_10, model)
```
```
SimpleModel(weight=1.0, bias=10.0)
```
In this vein, it is possible to define your own custom metadata and, therefore, your own metadata transformations, as sketched below.
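A minimal sketch, assuming metadata attached via `dataclasses.field(metadata=...)` is surfaced to `meta_map` alongside mytree's own keys; the class `Labelled`, its `scale` field, and the `"unit"` key are illustrative only.

```python
from dataclasses import dataclass, field
from mytree import Mytree, meta_map

@dataclass
class Labelled(Mytree):
    # Attach a custom "unit" entry to this leaf's metadata.
    scale: float = field(default=1.0, metadata={"unit": "metres"})

def to_centimetres(meta_leaf):
    meta, leaf = meta_leaf
    # Rescale only the leaves tagged with a "metres" unit.
    return leaf * 100.0 if meta.get("unit") == "metres" else leaf

meta_map(to_centimetres, Labelled())  # Expected: Labelled(scale=100.0)
```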
Since `Mytree` inherits from simple-pytree's `Pytree`, fields can be marked as static via simple-pytree's `static_field`.
```python
import jax.tree_util as jtu
from simple_pytree import static_field

class StaticExample(Mytree):
    b: float = static_field()

    def __init__(self, a=1.0, b=2.0):
        self.a = a
        self.b = b

jtu.tree_leaves(StaticExample())
```
```
[1.0]
```
Preliminary benchmarks can be found at https://github.com/Daniel-Dodd/mytree/blob/master/benchmarks/benchmarks.ipynb