
What is the difference between Module.param() and Module.variable()? #919

Answered by marcvanzee
marcvanzee asked this question in Q&A

`Module.param()` is a special case of `Module.variable()` that assumes three things:

  1. the variable is immutable, so it returns the parameter value directly rather than a mutable reference (a `Variable`) to it
  2. the param init function accepts a PRNGKey as its first argument (in addition to any arguments you pass explicitly)
  3. the variable collection name is `'params'`

In code:

```python
# Inside a Flax Linen Module method (e.g. one decorated with @nn.compact):
p = self.param('param_name', init_fn, shape, dtype)
# is a convenient shorthand for this:
p = self.variable('params', 'param_name',
                  lambda s, d: init_fn(self.make_rng('params'), s, d),
                  shape, dtype).value
```

So `self.param` is simply a convenience, because 95% or more of the time users want to define NN parameters, but the `self.variable` call lets you do whatever yo…

This discussion was converted from issue #852 on January 21, 2021 14:09.