import jax
import jax.numpy as jnp
from flax import core
-
-# from loss import mse
+from sympy import Float, Integer


def maml_adapt(
    params: core.FrozenDict[str, Any],
    apply_fn: Callable[[core.FrozenDict[str, Any], jnp.ndarray], jnp.ndarray],
    loss_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray],
    support_set: Tuple[jnp.ndarray, jnp.ndarray],
+    maml_lr: Float,
+    fas: Integer,
) -> core.FrozenDict[str, Any]:

    """Adapts with respect to the support set using the MAML algorithm.
@@ -23,17 +24,15 @@ def maml_adapt(
        apply_fn: A function that applies the model to a batch of data.
        loss_fn: A function that computes the loss of a batch of data.
        support_set: A tuple of (x_train, y_train).
-
+        maml_lr: Inner learning rate.
+        fas: Fast adaptation steps.
    Returns:
        adapted_params: adapted parameters
    """

    theta = params["params"]
    mutable_params = [key for key in params if key != "params"]

-    maml_lr = 0.01  # Inner Learning rate. TODO: take this parameter as an argument
-    fas = 1  # Fast adaptation steps. TODO: take this parameter as an argument
-
    def loss(theta, batch):
        x_train, y_train = batch
        logits, new_mutable_param_values = apply_fn(
@@ -49,29 +48,3 @@ def loss(theta, batch):
        theta = jax.tree_util.tree_map(lambda t, g: t - maml_lr * g, theta, grads)

    return theta
-
-
-# def maml_init(model: nn.Module, init_key, arr: jnp.ndarray):
-#     """Initializes the parameters of the model.
-
-#     The default parameters initilized by flax don't convege for
-#     optimization based meta learning algorithms.
-#     Hence they are scaled to match a normal distribution with mean 0 and std 0.01.
-
-#     Args:
-#         model (nn.Module): model whose parameters are to be initialised
-#         init_key (random.PRNGKey): PRNG Key used for initialisation
-#         arr (jnp.ndarray): a random array used to initialize the parameters
-
-#     Returns:
-#         Parameters: A frozen dict of model parameters.
-#     """
-
-#     EPSILON = 1e-8  # to avoid division by zero
-#     params = model.init(init_key, arr).unfreeze()
-#     # Paramters are scaled to match a normal distribution with mean 0 and std 0.01
-#     params = jax.tree_util.tree_map(
-#         lambda p: 0.01 * (p - p.mean()) / (p.std() + EPSILON), params
-#     )
-#     params = core.frozen_dict.freeze(params)
-#     return params
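
With this change, the inner learning rate and the number of fast adaptation steps are supplied by the caller instead of being hardcoded inside the function. A minimal calling sketch follows, assuming maml_adapt from this module is in scope; the MLP model, the mse loss, and the apply_fn wrapper are illustrative assumptions, not part of this commit, and plain Python scalars are passed for maml_lr and fas since the annotations are not enforced at runtime.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical single-layer model used only to illustrate the call; any
# flax.linen Module whose init() produces a {"params": ...} dict should work.
class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

model = MLP()

def mse(logits, targets):
    # Simple mean-squared-error loss matching loss_fn's (predictions, targets) contract.
    return jnp.mean((logits - targets) ** 2)

def apply_fn(variables, x):
    # maml_adapt's inner loss unpacks (logits, updated_mutable_collections);
    # this toy model has no mutable collections, so an empty dict is returned.
    return model.apply(variables, x), {}

key = jax.random.PRNGKey(0)
x_support = jax.random.normal(key, (10, 1))
y_support = 2.0 * x_support - 1.0
params = model.init(key, x_support)

# One inner gradient step at learning rate 0.01 on the support set.
adapted_params = maml_adapt(params, apply_fn, mse, (x_support, y_support), maml_lr=0.01, fas=1)

Taking maml_lr and fas as arguments also lets an outer meta-training loop sweep or schedule the inner-loop hyperparameters without editing the adaptation code.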