Hooks: the one PyTorch trick you must know
If you have ever worked with deep learning, you know that debugging a model can be really hard. Tensor shape mismatches, exploding gradients, and countless other issues can catch you by surprise. Solving them requires looking at the model under a microscope. The most basic methods are littering the `forward()` methods with print statements or setting breakpoints. These do not scale, because they require guessing where the problem is, and they are tedious overall.
Fortunately, there is a better solution: hooks. A hook is a function that can be attached to any layer and is called each time that layer is used. Hooks let you intercept the forward or backward pass at a specific module and process its inputs and outputs.
Let’s see them in action!
## Hooks crash course
So, a hook is just a callable object with a predefined signature, which can be registered to any `nn.Module` object. When the module is triggered (i.e., when its forward or backward pass runs), the module itself, along with its inputs and possibly its outputs (for backward hooks, the gradients with respect to them), is passed to the hook, which executes before the computation proceeds to the next module.
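To make the hook signature concrete, here is a minimal sketch, assuming a recent PyTorch release, in which a plain function (the illustrative `print_shapes`) is registered as a forward hook on a single `nn.Linear` layer. Note that the hook fires when the module is called as `layer(x)`; calling `layer.forward(x)` directly bypasses hooks.

```python
import torch
from torch import nn

calls = []  # records which modules triggered the hook

def print_shapes(module, inputs, output):
    # A forward hook receives the module, a tuple of its positional inputs,
    # and its output. Returning None leaves the output unchanged.
    calls.append(type(module).__name__)
    print(f"{type(module).__name__}: "
          f"in={tuple(inputs[0].shape)} out={tuple(output.shape)}")

layer = nn.Linear(4, 2)
layer.register_forward_hook(print_shapes)

# Calling the module (not layer.forward directly) triggers the hook.
y = layer(torch.randn(3, 4))
# prints: Linear: in=(3, 4) out=(3, 2)
```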
In PyTorch, you can register a hook as a
- forward pre-hook (executing before the forward pass),
- forward hook (executing after the forward pass),
- backward hook (executing once the gradients with respect to the module's inputs have been computed during the backward pass).
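The three kinds of hooks above, and the order in which they fire, can be seen in the following minimal sketch. It assumes PyTorch 1.8 or later, where `register_full_backward_hook` is available (it supersedes the older `register_backward_hook`); the hook functions are illustrative names.

```python
import torch
from torch import nn

events = []

def pre_hook(module, inputs):
    # Runs before forward(); returning a value would replace the inputs.
    events.append("forward pre-hook")

def fwd_hook(module, inputs, output):
    # Runs after forward(); returning a value would replace the output.
    events.append("forward hook")

def bwd_hook(module, grad_input, grad_output):
    # Runs once the gradients w.r.t. the module's inputs are computed.
    events.append("backward hook")

layer = nn.Linear(3, 1)
layer.register_forward_pre_hook(pre_hook)
layer.register_forward_hook(fwd_hook)
layer.register_full_backward_hook(bwd_hook)

x = torch.randn(2, 3, requires_grad=True)
layer(x).sum().backward()
print(events)
# ['forward pre-hook', 'forward hook', 'backward hook']
```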
It might sound complicated at first, so let’s take a look at a concrete example!
## An example: saving the outputs of each convolutional layer
Suppose that we want to inspect the output of each convolutional layer in a ResNet34 architecture. This task is perfectly suited to hooks. In the next part, I will show you how it can be done. If you want to follow along interactively, you can find the accompanying Jupyter notebook here.
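Our model is defined as follows:

```python
import torch
from torchvision.models import resnet34, ResNet34_Weights

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load ImageNet-pretrained weights (older torchvision versions use pretrained=True)
model = resnet34(weights=ResNet34_Weights.DEFAULT)
model = model.to(device)
model.eval()  # inference mode: batch norm uses its running statistics
```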
Creating a hook to save outputs is very simple; a basic callable object is perfectly enough for our purposes.
```python
class SaveOutput:
    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # Called by PyTorch right after the module's forward pass
        self.outputs.append(module_out)

    def clear(self):
        self.outputs = []
```
An instance of `SaveOutput` simply records the output tensor of each forward pass and stores it in a list.
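One design point worth noting: `SaveOutput` stores `module_out` as-is, so when gradients are enabled each stored tensor stays attached to the autograd graph and keeps it in memory until the list is cleared. If the outputs are only needed for inspection, a variant could store detached copies instead. `SaveDetachedOutput` below is a hypothetical name for such a variant, shown on a single toy convolution.

```python
import torch
from torch import nn

class SaveDetachedOutput:
    """Forward hook that keeps graph-free CPU copies of module outputs."""

    def __init__(self):
        self.outputs = []

    def __call__(self, module, module_in, module_out):
        # detach() drops the autograd history; cpu() releases GPU memory.
        self.outputs.append(module_out.detach().cpu())

    def clear(self):
        self.outputs = []

saver = SaveDetachedOutput()
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv.register_forward_hook(saver)
conv(torch.randn(1, 3, 16, 16))
print(saver.outputs[0].shape, saver.outputs[0].requires_grad)
# torch.Size([1, 8, 16, 16]) False
```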
We can register a forward hook with the `register_forward_hook` method. (For the other types of hooks, there are `register_forward_pre_hook` and `register_full_backward_hook`.) Each of these methods returns a hook handle, whose `remove()` method detaches the hook from the module.
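Removing hooks matters once you are done inspecting, since a forgotten hook keeps firing on every later forward pass. A minimal sketch of the handle mechanism on a toy `nn.ReLU` layer (the lambda hook is purely illustrative):

```python
import torch
from torch import nn

outputs = []
layer = nn.ReLU()

# register_forward_hook returns a torch.utils.hooks.RemovableHandle
handle = layer.register_forward_hook(lambda m, i, o: outputs.append(o))

layer(torch.tensor([-1.0, 2.0]))  # hook fires
handle.remove()                    # detach the hook from the module
layer(torch.tensor([3.0, -4.0]))  # hook no longer fires

print(len(outputs))  # 1
```

For a list of handles such as `hook_handles`, cleanup is simply a loop calling `remove()` on each one.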
Now we register the hook to each convolutional layer.
```python
save_output = SaveOutput()
hook_handles = []

# Attach the same SaveOutput instance to every Conv2d layer
for layer in model.modules():
    if isinstance(layer, torch.nn.Conv2d):
        handle = layer.register_forward_hook(save_output)
        hook_handles.append(handle)
```
Once this is done, the hook will be called after the forward pass of every convolutional layer. To test it out, we are going to use a picture of a cat, saved as `cat.jpg`.
The forward pass:
```python
from PIL import Image
from torchvision import transforms as T

image = Image.open('cat.jpg')
transform = T.Compose([T.Resize((224, 224)), T.ToTensor()])

# Add a batch dimension: shape (1, 3, 224, 224)
X = transform(image).unsqueeze(dim=0).to(device)

out = model(X)
```
As expected, the outputs were appropriately stored.
```python
>>> len(save_output.outputs)
36
```
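Because `SaveOutput` collects everything in one flat list, matching each tensor to its layer means reasoning about the order in which the convolutions execute. A possible alternative, sketched here on a small `nn.Sequential` stand-in rather than ResNet34, keys each output by the module's name from `named_modules()`; the `make_hook` helper and the `activations` dictionary are hypothetical names.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1),
)

activations = {}

def make_hook(name):
    # Closure so that each hook knows which layer it is attached to.
    def hook(module, module_in, module_out):
        activations[name] = module_out.detach()
    return hook

handles = [
    layer.register_forward_hook(make_hook(name))
    for name, layer in model.named_modules()
    if isinstance(layer, nn.Conv2d)
]

model(torch.randn(1, 3, 32, 32))
for name, act in activations.items():
    print(name, tuple(act.shape))
# 0 (1, 4, 32, 32)
# 2 (1, 8, 16, 16)

for h in handles:
    h.remove()
```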
By inspecting the tensors in `save_output.outputs`, we can visualize what the network sees.
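One way to turn a stored activation into a picture is to min-max scale each channel and tile the channels into a grid. The helper below, `feature_map_grid`, is a hypothetical sketch; the random tensor stands in for an entry such as `save_output.outputs[0]`, which for the first convolution of ResNet34 on a 224×224 input has shape (1, 64, 112, 112). The resulting 2D tensor can then be shown with any image viewer, for example matplotlib's `imshow`.

```python
import torch

def feature_map_grid(fmap: torch.Tensor, ncols: int = 8) -> torch.Tensor:
    """Tile the channels of a (1, C, H, W) activation into one 2D image.

    Each channel is min-max scaled to [0, 1] independently so that
    weakly activated channels remain visible.
    """
    fmap = fmap.detach().cpu().float()[0]  # (C, H, W)
    c, h, w = fmap.shape
    flat = fmap.reshape(c, -1)
    lo = flat.min(dim=1, keepdim=True).values
    hi = flat.max(dim=1, keepdim=True).values
    scaled = ((flat - lo) / (hi - lo + 1e-8)).reshape(c, h, w)

    nrows = (c + ncols - 1) // ncols
    grid = torch.zeros(nrows * h, ncols * w)
    for i in range(c):
        r, col = divmod(i, ncols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = scaled[i]
    return grid

grid = feature_map_grid(torch.randn(1, 64, 112, 112))
print(grid.shape)  # torch.Size([896, 896])
```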
Out of curiosity, we can also check what happens further on. The deeper we go into the network, the more high-level the learned features become. For instance, one filter seems to be responsible for detecting the eyes.
Of course, this is just the tip of the iceberg. Hooks can do much more than simply store outputs of intermediate layers. For instance, neural network pruning, a technique to reduce the number of parameters, can also be performed with hooks.
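PyTorch's own pruning utilities are one example: `torch.nn.utils.prune` keeps the original weights as `weight_orig`, stores a binary `weight_mask` buffer, and reapplies the mask in a forward pre-hook before every forward pass. A short sketch of magnitude (L1) pruning on a single, arbitrarily sized convolution:

```python
import torch
from torch import nn
from torch.nn.utils import prune

conv = nn.Conv2d(3, 16, kernel_size=3)

# Zero out the 50% of weights with the smallest absolute value.
# This registers a forward pre-hook that recomputes
# weight = weight_orig * weight_mask before each forward pass.
prune.l1_unstructured(conv, name="weight", amount=0.5)

sparsity = (conv.weight == 0).float().mean().item()
print(f"sparsity: {sparsity:.2f}")  # sparsity: 0.50

# The pruned layer is used exactly like before.
y = conv(torch.randn(1, 3, 8, 8))
```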
To summarize, hooks are a handy technique to learn if you want to improve your workflow. With them under your belt, you'll be able to do much more, and do it more effectively.