Commit e55bf83

Merge pull request #39 from AdityaD16/updated_maml
Added fas and maml_lr as arguments.
2 parents: 1a6e5ae + 2904105
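
The commit removes the hard-coded inner-loop settings and lets callers pass them in. A minimal caller sketch, assuming the jeta package is importable as jeta.maml and using a placeholder Flax model (model, mse, and the arrays below are illustrative, not part of this repository):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn
    from jeta.maml import maml_adapt

    model = nn.Dense(features=1)                      # placeholder model for illustration
    params = model.init(jax.random.PRNGKey(0), jnp.ones((4, 3)))

    def mse(logits, targets):                         # simple loss_fn for the sketch
        return jnp.mean((logits - targets) ** 2)

    adapted_params = maml_adapt(
        params,
        model.apply,                                  # apply_fn
        mse,                                          # loss_fn
        (jnp.ones((4, 3)), jnp.zeros((4, 1))),        # support_set: (x_train, y_train)
        maml_lr=0.01,                                 # was hard-coded to 0.01 before this change
        fas=1,                                        # was hard-coded to 1 before this change
    )

Passing the old defaults (0.01 and 1) reproduces the previous behaviour.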

1 file changed: +5 -32 lines changed


jeta/maml.py

Lines changed: 5 additions & 32 deletions
@@ -3,15 +3,16 @@
 import jax
 import jax.numpy as jnp
 from flax import core
-
-# from loss import mse
+from sympy import Float, Integer


 def maml_adapt(
     params: core.FrozenDict[str, Any],
     apply_fn: Callable[[core.FrozenDict[str, Any], jnp.ndarray], jnp.ndarray],
     loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
     support_set: Tuple[jnp.ndarray, jnp.ndarray],
+    maml_lr: Float,
+    fas: Integer,
 ) -> core.FrozenDict[str, Any]:

     """Adapts with respect to the support set using the MAML algorithm.
@@ -23,17 +24,15 @@ def maml_adapt(
         apply_fn: A function that applies the model to a batch of data.
         loss_fn: A function that computes the loss of a batch of data.
         support_set: A tuple of (x_train, y_train).
-
+        maml_lr : Inner learning rate.
+        fas: Fast adaption step.
     Returns:
         adapted_params: adapted parameters
     """

     theta = params["params"]
     mutable_params = [key for key in params if key != "params"]

-    maml_lr = 0.01 # Inner Learning rate. TODO: take this parameter as an argument
-    fas = 1 # Fast adaptation steps. TODO: take this parameter as an argument
-
     def loss(theta, batch):
         x_train, y_train = batch
         logits, new_mutable_param_values = apply_fn(
@@ -49,29 +48,3 @@ def loss(theta, batch):
         theta = jax.tree_util.tree_map(lambda t, g: t - maml_lr * g, theta, grads)

     return theta
-
-
-# def maml_init(model: nn.Module, init_key, arr: jnp.ndarray):
-# """Initializes the parameters of the model.
-
-# The default parameters initilized by flax don't convege for
-# optimization based meta learning algorithms.
-# Hence they are scaled to match a normal distribution with mean 0 and std 0.01.
-
-# Args:
-# model (nn.Module): model whose parameters are to be initialised
-# init_key (random.PRNGKey): PRNG Key used for initialisation
-# arr (jnp.ndarray): a random array used to initialize the parameters
-
-# Returns:
-# Parameters: A frozen dict of model parameters.
-# """
-
-# EPSILON = 1e-8 # to avoid division by zero
-# params = model.init(init_key, arr).unfreeze()
-# # Paramters are scaled to match a normal distribution with mean 0 and std 0.01
-# params = jax.tree_util.tree_map(
-# lambda p: 0.01 * (p - p.mean()) / (p.std() + EPSILON), params
-# )
-# params = core.frozen_dict.freeze(params)
-# return params
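
For context, fas and maml_lr parameterize MAML's fast-adaptation loop: fas gradient steps on the support-set loss, each with step size maml_lr. A self-contained toy sketch of that update rule (not the repository's exact inner loop, whose body is only partially visible in this diff):

    import jax
    import jax.numpy as jnp

    def loss(theta, batch):
        x, y = batch
        pred = x @ theta["w"] + theta["b"]            # toy linear model standing in for apply_fn
        return jnp.mean((pred - y) ** 2)              # mean-squared-error support loss

    theta = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
    support_set = (jnp.ones((4, 3)), jnp.zeros((4, 1)))
    maml_lr, fas = 0.01, 1                            # the values that were hard-coded before this commit

    for _ in range(fas):                              # fas: number of fast-adaptation steps
        grads = jax.grad(loss)(theta, support_set)
        # SGD update on the support set, matching the tree_map line in maml_adapt above
        theta = jax.tree_util.tree_map(lambda t, g: t - maml_lr * g, theta, grads)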
