In this tutorial, we will introduce how to list all trainable variables (parameters) in PyTorch. This is useful when fine-tuning a model, because you often freeze some layers and want to check which parameters will still be updated.
How to list all trainable variables in PyTorch?
It is easy to implement. Here is an example:
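```python
params = system.parameters()
for v in params:
    # only trainable parameters have requires_grad=True
    if v.requires_grad:
        # numel() is the number of elements in the parameter tensor
        print(v.numel(), v.requires_grad)
```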
Here, system is a PyTorch model (a torch.nn.Module). A parameter is trainable if its requires_grad attribute is True.
Running this code prints one line per trainable parameter: its number of elements followed by True.
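To see the loop in a fine-tuning setting, here is a minimal, self-contained sketch. The two-layer nn.Sequential model below is a hypothetical stand-in for system, and freezing its first layer is an assumption chosen only to show that frozen parameters are skipped.

```python
import torch.nn as nn

# Hypothetical stand-in for "system": a small two-layer network.
system = nn.Sequential(
    nn.Linear(4, 8),  # index 0: weight 8x4, bias 8
    nn.ReLU(),
    nn.Linear(8, 2),  # index 2: weight 2x8, bias 2
)

# Freeze the first layer, as one might when fine-tuning only the last layer.
for p in system[0].parameters():
    p.requires_grad = False

# Keep only parameters that will receive gradients.
trainable = [p for p in system.parameters() if p.requires_grad]
for p in trainable:
    print(p.numel(), p.requires_grad)

# Total number of trainable elements across all trainable parameters.
total = sum(p.numel() for p in trainable)
print("trainable elements:", total)
```

With the first layer frozen, only the second layer's weight (2 x 8 = 16 elements) and bias (2 elements) are listed, so the total is 18.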
If you also want to show each parameter's name, use named_parameters(), which yields (name, parameter) pairs:
```python
params = system.named_parameters()
for name, param in params:
    # name is the parameter's name, param is the parameter tensor
    if param.requires_grad:
        print(name, param.requires_grad)
```
This prints the name of each trainable parameter (its dotted path inside the model) followed by True.
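A common use of the names is to check which parts of a model are frozen before fine-tuning. The sketch below assumes a hypothetical model class TinyNet whose attribute names backbone and head are illustrative only; it freezes the backbone with nn.Module.requires_grad_(False) and splits the parameter names into trainable and frozen lists.

```python
import torch.nn as nn

# Hypothetical model: "backbone" and "head" are illustrative attribute names.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def forward(self, x):
        return self.head(self.backbone(x).relu())

system = TinyNet()
system.backbone.requires_grad_(False)  # freeze every parameter in the backbone

# Split parameter names by whether they will be updated during training.
trainable_names = [n for n, p in system.named_parameters() if p.requires_grad]
frozen_names = [n for n, p in system.named_parameters() if not p.requires_grad]
print("trainable:", trainable_names)  # ['head.weight', 'head.bias']
print("frozen:", frozen_names)        # ['backbone.weight', 'backbone.bias']
```

The names follow the attribute path inside the model, so each entry tells you exactly which submodule a parameter belongs to.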