Commit 09af5bd
Replace nn.Moudle by nn.Module (huggingface#12541)
SaulLu authored Jul 6, 2021
1 parent f42a0ab commit 09af5bd
Showing 1 changed file with 1 addition and 1 deletion.
examples/research_projects/jax-projects/README.md: 2 changes (1 addition & 1 deletion)
@@ -712,7 +712,7 @@ class FlaxMLPModel(FlaxMLPPreTrainedModel):

Now the `FlaxMLPModel` will have an interface similar to PyTorch or TensorFlow models, allowing us to attach loaded or randomly initialized weights to the model instance.

- The important point to remember is that `model` is not an instance of `nn.Moudle`; it is an abstract class, like a container that holds a Flax module and its parameters, and provides convenient methods for initialization and the forward pass. The key takeaway here is that an instance of `FlaxMLPModel` is very much stateful, since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. To make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well, making it stateless and easier to work with during training. Feel free to take a look at the code to see exactly how this is implemented, e.g. in [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py#L536).
+ The important point to remember is that `model` is not an instance of `nn.Module`; it is an abstract class, like a container that holds a Flax module and its parameters, and provides convenient methods for initialization and the forward pass. The key takeaway here is that an instance of `FlaxMLPModel` is very much stateful, since it holds all the model parameters, whereas the underlying Flax module `FlaxMLPModule` is still stateless. To make `FlaxMLPModel` fully compliant with JAX transformations, it is always possible to pass the parameters to `FlaxMLPModel` as well, making it stateless and easier to work with during training. Feel free to take a look at the code to see exactly how this is implemented, e.g. in [`modeling_flax_bert.py`](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bert/modeling_flax_bert.py#L536).
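
To make the stateful vs. stateless distinction concrete, here is a minimal sketch; it assumes the hypothetical `FlaxMLPModel` from this README follows the same `params` keyword convention as the Transformers Flax models, and the config and input shape are made up for illustration:

```python
import jax.numpy as jnp

# Assumed setup: `FlaxMLPModel` and `config` come from the example
# developed earlier in this README; the input shape is arbitrary.
model = FlaxMLPModel(config)     # stateful wrapper, holds `model.params`
inputs = jnp.ones((2, 16))

# Stateful call: the wrapper uses the parameters it stores internally.
outputs = model(inputs)

# Stateless call: the parameters are passed in explicitly, which is
# what you want inside functions transformed by `jax.jit`/`jax.grad`.
outputs = model(inputs, params=model.params)
```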

Another significant difference between Flax and PyTorch models is that we can pass the `labels` directly to PyTorch's forward pass to compute the loss, whereas Flax models never accept `labels` as an input argument. In PyTorch, gradient backpropagation is performed by simply calling `.backward()` on the computed loss, which makes it very handy for the user to pass the `labels`. In Flax, however, gradient backpropagation cannot be done by calling `.backward()` on the loss output; instead, the loss function itself has to be transformed by `jax.grad` or `jax.value_and_grad` to return the gradients of all parameters. This transformation cannot happen under the hood when one passes the `labels` to Flax's forward function, so in Flax, `labels` are simply not allowed to be passed by design, and the user is required to implement the loss function themselves. As a result, you will see that all training-related code is decoupled from the modeling code and always defined in the training scripts themselves, as in the sketch below.
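
As an illustration of that design, a Flax training step typically looks like the following sketch; `model`, `params`, and the batch contents are placeholders rather than code from this repository, and a plain mean-squared error plus an SGD update stand in for a real objective and optimizer:

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # The labels never enter the model: the forward pass only sees the
    # inputs and the explicitly passed parameters ...
    logits = model(batch["inputs"], params=params)
    # ... and the loss function is where the labels are consumed.
    return jnp.mean((logits - batch["labels"]) ** 2)

# `jax.value_and_grad` transforms the loss function into one that returns
# both the scalar loss and the gradients w.r.t. its first argument.
loss, grads = jax.value_and_grad(loss_fn)(params, batch)

# Minimal SGD update, just to show how the gradients are applied.
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
```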

