For params we use shape inference to check that the initialiser and its value have the same shape. This avoids a lot of issues with hyperparameters and params getting out of sync, for example after restoring a checkpoint. We might at some point add a keyword arg to disable this check, but for now an easy workaround is:

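import jax.numpy as jnp
from flax import linen

class A(linen.Module):
    size: int

    def setup(self):
        # self.variable skips the shape check that self.param applies.
        self.array = self.variable('params', 'array', jnp.zeros, self.size)

    def __call__(self):
        # self.variable returns a Variable; read the array through .value.
        return self.array.value.mean()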

Btw, I would consider putting such a variable in a collection other than "params" anyway. Quite often you need to enforce shape invariance outside of the model as well …

Answer selected by marcvanzee