How to compress a neural network

Tivadar Danka

Modern state-of-the-art neural network architectures are HUGE. For instance, you have probably heard about GPT-3, OpenAI's newest revolutionary NLP model, capable of writing poetry and interactive storytelling.

Well, GPT-3 has around 175 billion parameters.

To give you a perspective on how large this number is, consider the following. Suppose you had one dollar for each parameter, paid out in $100 bills: that is 1.75 billion bills. A $100 bill is approximately 6.14 inches wide, so if you laid the bills down next to each other, the line would stretch 169,586 miles. For comparison, Earth's circumference is 24,901 miles, measured along the equator. So the line would wrap around the Earth about 6.8 times before we ran out of money.
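The arithmetic is easy to double-check. A quick sketch, assuming one dollar per parameter paid out in $100 bills, with the bill width and equatorial circumference quoted above:

```python
# One dollar per GPT-3 parameter, paid out in $100 bills.
params = 175e9
bills = params / 100            # 1.75 billion bills
bill_width_in = 6.14            # width of a $100 bill, in inches
inches_per_mile = 63_360        # 5,280 feet x 12 inches
equator_miles = 24_901          # Earth's circumference along the equator

line_miles = bills * bill_width_in / inches_per_mile
trips = line_miles / equator_miles

print(f"{line_miles:,.0f} miles, {trips:.1f} trips around the equator")
# 169,586 miles, 6.8 trips around the equator
```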

Unfortunately, unlike money, more is not always better when it comes to the number of parameters. Sure, more parameters seem to mean better results, but they also mean much higher costs. According to the original paper, training GPT-3 took about 3.14E+23 floating-point operations (FLOPs) of compute, and the computing cost itself is estimated to be in the millions of dollars.

GPT-3 is so large that it cannot be easily moved to other machines. It is currently accessible through the OpenAI API, so you can't just clone a GitHub repository and run it on your computer.

However, this is just the tip of the iceberg. Deploying much smaller models can also present a significant challenge for machine learning engineers. In practice, small and fast models are much easier to deploy and cheaper to run than cumbersome ones.

Because of this, researchers and engineers have put significant energy into compressing models. Out of these efforts, several methods have emerged to deal with the problem.

The why and the how

If we revisit GPT-3 for a minute, we can see how the number of parameters and the training time influence performance.

Validation loss vs. compute time in different variants of the GPT-3 model. Colors represent the number of parameters. Source: Language Models are Few-Shot Learners by Tom B. Brown et al. (https://arxiv.org/pdf/2005.14165.pdf)

The trend seems clear: more parameters lead to better performance and higher computational costs. The latter impacts not only the training time but also the server costs and the environmental footprint. (Training large models can emit more CO2 than a car in its entire lifetime.) However, training is only the first part of the life cycle of a neural network. In the long run, inference costs take over.

Three main methods have emerged for optimizing these costs by compressing the models:

  • weight pruning,
  • quantization,
  • knowledge distillation.

In this post, my goal is to introduce you to these methods and give an overview of how they work.

Let's get started!

Weight pruning

One of the oldest methods for reducing a neural network's size is weight pruning, eliminating specific connections between neurons. In practice, elimination means that the removed weight is replaced with zero.
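In code, pruning amounts to multiplying the weights by a 0/1 mask. A minimal sketch for a single fully connected layer, where the weights, input, and mask are made up and `dense` and `prune` are hypothetical helpers:

```python
def dense(x, weights):
    """Fully connected layer without bias or activation: y_j = sum_i w_ji * x_i."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def prune(weights, mask):
    """Return a copy of the weights where every entry with mask 0 is set to zero."""
    return [[w * m for w, m in zip(row, mrow)] for row, mrow in zip(weights, mask)]

weights = [[0.8, -0.05, 0.3],
           [0.01, -1.2, 0.4]]
# Hypothetical mask: remove the two connections with tiny weights.
mask = [[1, 0, 1],
        [0, 1, 1]]

x = [1.0, 2.0, 3.0]
print(dense(x, weights))                 # original outputs
print(dense(x, prune(weights, mask)))    # outputs after pruning
```

Here the outputs change only slightly. In practice, the savings come from storing only the nonzero weights in a sparse format, or from removing whole neurons or channels; multiplying by zero alone saves nothing.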

At first glance, this idea might be surprising. Wouldn't this eliminate the knowledge learned by the neural network?

Sure, removing all of the connections would undoubtedly erase everything the network has learned. At the other end of the spectrum, pruning a single connection probably wouldn't cause any noticeable decrease in accuracy.

The question is, how much can you remove until the predictive performance starts to suffer?

Optimal Brain Damage

Among the first to study this question were Yann LeCun, John S. Denker, and Sara A. Solla, who developed the following iterative method in their 1990 paper Optimal Brain Damage.

  1. Train a network.
  2. Estimate the importance of each weight from how much the loss would change if that weight were deleted. A smaller change means less importance. (This importance is called saliency. Optimal Brain Damage approximates it using second derivatives of the loss.)
  3. Remove the weights with low importance.
  4. Go back to Step 1 and retrain the network, permanently fixing the removed weights to zero.
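Optimal Brain Damage makes step 2 concrete with a second-order Taylor expansion of the loss. At a trained minimum the gradient term vanishes, and if only the diagonal of the Hessian is kept, deleting weight $w_k$ changes the loss by roughly $s_k = \tfrac{1}{2} h_{kk} w_k^2$, where $h_{kk}$ is the second derivative of the loss with respect to $w_k$. The paper computes $h_{kk}$ with a backpropagation-like pass; the sketch below swaps in finite differences on a toy linear model, and the data, weights, and helper names are made up for illustration:

```python
def predict(w, x):
    """Toy linear model y_hat = w . x (stands in for a trained network)."""
    return sum(wi * xi for wi, xi in zip(w, x))

def loss(w, data):
    """Mean squared error on (x, y) pairs."""
    return sum((predict(w, x) - y) ** 2 for x, y in data) / len(data)

def hessian_diag(w, data, eps=1e-3):
    """Central finite-difference estimate of d^2 L / d w_k^2 for every k."""
    diag = []
    for k in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[k] += eps
        w_minus[k] -= eps
        diag.append((loss(w_plus, data) - 2 * loss(w, data) + loss(w_minus, data)) / eps ** 2)
    return diag

def obd_saliency(w, data):
    """OBD saliency s_k = h_kk * w_k^2 / 2 (diagonal Hessian approximation)."""
    return [0.5 * h * wk ** 2 for h, wk in zip(hessian_diag(w, data), w)]

# Made-up "trained" weights; the targets are generated so that they are a minimum.
w_trained = [2.0, 0.05, -1.0]
xs = [[1.0, 0.5, 2.0], [0.0, 1.0, 1.0], [2.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
data = [(x, predict(w_trained, x)) for x in xs]

saliency = obd_saliency(w_trained, data)
k = min(range(len(w_trained)), key=lambda i: saliency[i])   # least important weight
pruned = list(w_trained)
pruned[k] = 0.0

print(saliency)
print(k, loss(pruned, data))   # actual loss after deleting weight k
```

Because this toy loss is exactly quadratic and the weights sit at its minimum, the saliency of the pruned weight matches the true loss increase; for a real network it is only an approximation.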

During their experiments with pruning a LeNet-style network for handwritten digit recognition, they could remove a significant portion of the weights without a noticeable increase in the loss.

Source: Optimal Brain Damage by Yann LeCun, John S. Denker and Sara A. Solla (https://proceedings.neurips.cc/paper/1989/file/6c9882bbac1c7093bd25041881277658-Paper.pdf)

However, the pruned network had to be retrained. This proved to be quite tricky since a smaller model means a smaller capacity. Besides, as mentioned above, training amounts to a significant portion of the computational costs, and this kind of compression only reduces the cost of inference.

Is there a method requiring less post-pruning training but still reaching the unpruned model's predictive performance?

Lottery Ticket Hypothesis

One essential breakthrough was made in 2018 by researchers from MIT. In their paper titled The Lottery Ticket Hypothesis, Jonathan Frankle and Michael Carbin stated the following hypothesis:

A randomly-initialized, dense neural network contains a subnetwork that is initialized such that - when trained in isolation - it can match the test accuracy of the original network after training for at most the same number of iterations.

Such subnetworks are called winning lottery tickets. To see why, suppose that you buy $10^{1000}$ lottery tickets. (This is more than the number of atoms in the observable universe, but we'll let this one slide.) Because you have so many, there is only a tiny probability that none of them is a winner. Training a neural network is similar: a large, randomly initialized network contains an enormous number of subnetworks, so it is likely that some of them received a lucky initialization.

If this hypothesis is true, and we can find such subnetworks, training could be done much faster and cheaper since a single iteration step would take less computation.

The question is, does the hypothesis hold, and if so, how can we find such subnetworks? The authors proposed the following iterative method.

  1. Randomly initialize the network and store the initial weights for later reference.
  2. Train the network for a given number of steps.
  3. Remove a percentage of the weights with the lowest magnitude.
  4. Reset the remaining weights to the values they had at initialization.
  5. Go back to Step 2 and repeat the pruning.
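The five steps above can be sketched on a toy problem. This is a minimal illustration under strong assumptions, not the authors' code: the "network" is a linear model with ten inputs, only three of which matter, trained with plain gradient descent on made-up data, and `train` and `lottery_ticket` are hypothetical helpers. The paper applies the same loop to real deep networks.

```python
import random

def train(w_start, mask, data, steps=200, lr=0.1):
    """Gradient descent on the mean squared error of y_hat = w . x.

    Pruned weights (mask 0) are zeroed before training and after every update,
    so they stay fixed at zero.
    """
    w = [wi * m for wi, m in zip(w_start, mask)]
    n = len(data)
    for _ in range(steps):
        grad = [0.0] * len(w)
        for x, y in data:
            err = sum(wi * xi for wi, xi in zip(w, x)) - y
            for k, xk in enumerate(x):
                grad[k] += 2 * err * xk / n
        w = [(wi - lr * g) * m for wi, g, m in zip(w, grad, mask)]
    return w

def lottery_ticket(w_init, data, rounds, prune_frac):
    mask = [1] * len(w_init)                      # step 1: w_init is kept for rewinding
    for _ in range(rounds):
        w_trained = train(w_init, mask, data)     # step 2: train
        alive = [k for k, m in enumerate(mask) if m]
        n_prune = int(len(alive) * prune_frac)
        # step 3: remove the surviving weights with the smallest magnitude
        for k in sorted(alive, key=lambda i: abs(w_trained[i]))[:n_prune]:
            mask[k] = 0
        # steps 4-5: the next round trains from w_init again (rewinding)
    return mask, train(w_init, mask, data)        # train the final ticket

rng = random.Random(0)
true_w = [3.0, -2.0, 1.5] + [0.0] * 7             # only 3 of 10 inputs matter
xs = [[rng.uniform(-1, 1) for _ in range(10)] for _ in range(100)]
data = [(x, sum(t * xi for t, xi in zip(true_w, x))) for x in xs]
w_init = [rng.uniform(-1, 1) for _ in range(10)]

mask, w_ticket = lottery_ticket(w_init, data, rounds=2, prune_frac=0.5)
print(mask)
print([round(w, 3) for w in w_ticket])
```

Multiplying by the mask after every update is what keeps removed weights at exactly zero, while the rewinding in steps 4 and 5 is simply restarting `train` from `w_init`.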

On simple architectures trained on simple datasets, such as LeNet on MNIST, this method found subnetworks that are much smaller than the original network, yet train faster and reach at least the same test accuracy, as shown in the figure below.

Source: The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks by Jonathan Frankle and Michael Carbin (https://arxiv.org/pdf/1803.03635.pdf)

Although the algorithm showed promise, it did not perform well on more complex architectures like ResNets. Moreover, finding a winning ticket still requires training the full network, repeatedly, which is a significant problem.

SynFlow

One of the most recent algorithms to prune before training was published in 2020. In their paper, Hidenori Tanaka, Daniel Kunin, Daniel L. K. Yamins, and Surya Ganguli from Stanford developed a method that goes much further: it prunes the network at initialization, without any training and without even looking at the data.

First, they introduce the concept of layer collapse: the premature pruning of an entire layer, which makes the network untrainable. This concept plays a significant part in their theory. Any pruning algorithm should avoid layer collapse; the hard part is identifying a class of algorithms that satisfies this criterion.

For this purpose, the authors introduce the synaptic saliency score for a given weight in the network defined by

$$S(w) = w \frac{\partial L}{\partial w},$$

where $L$ is a scalar function of the network's output, such as the loss, and $w$ is a weight parameter. Each neuron conserves this quantity: under certain constraints on the activation functions (ReLU satisfies them), the sum of the incoming synaptic saliency scores is equal to the sum of the outgoing synaptic saliency scores.
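To make the definition and the conservation property concrete, here is a minimal sketch on a two-layer ReLU network with a single output. As simplifying assumptions, L is taken to be the network's output itself rather than a loss, the gradients are written out by hand, and the weights and input are made-up numbers:

```python
def relu(z):
    return max(z, 0.0)

def synaptic_saliency(W, v, x):
    """S(w) = w * dL/dw for every weight of a two-layer ReLU net with scalar output
    L = sum_j v[j] * relu(sum_i W[j][i] * x[i]). Gradients are written out by hand."""
    z = [sum(w * xi for w, xi in zip(row, x)) for row in W]   # pre-activations
    h = [relu(zj) for zj in z]                                 # hidden outputs
    # Outgoing weights: dL/dv[j] = h[j]
    S_v = [vj * hj for vj, hj in zip(v, h)]
    # Incoming weights: dL/dW[j][i] = v[j] * relu'(z[j]) * x[i]
    S_W = [[w * vj * (1.0 if zj > 0 else 0.0) * xi for w, xi in zip(row, x)]
           for row, vj, zj in zip(W, v, z)]
    return S_W, S_v

W = [[0.5, -1.0, 2.0],
     [0.3, 0.8, -0.6]]
v = [1.5, -2.0]
x = [1.0, 0.5, 0.25]

S_W, S_v = synaptic_saliency(W, v, x)
for j in range(len(v)):
    # sum of incoming scores of hidden neuron j vs. its outgoing score
    print(j, sum(S_W[j]), S_v[j])
```

For each hidden neuron, the two printed numbers agree up to rounding. This works because ReLU satisfies relu(z) = relu'(z) z, so the incoming sum collapses to v[j] * h[j].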

This score is used to select which weights are pruned. (Recall that the Optimal Brain Damage method used a perturbation-based quantity for this purpose, while the authors of the Lottery Ticket Hypothesis paper used the magnitude.)

It turns out that synaptic saliency scores are conserved between layers, and roughly speaking, if an iterative pruning algorithm respects this layer-wise conservation, layer collapse can be avoided.

The SynFlow algorithm is an iterative pruning algorithm similar to the previous ones, but the selection is based on the synaptic saliency scores.

Source: Pruning neural networks without any data by iteratively conserving synaptic flow by Hidenori Tanaka, Daniel Kunin, Daniel L. K. Yamins, and Surya Ganguli (https://arxiv.org/pdf/2006.05467.pdf)
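A toy version of SynFlow can be sketched under strong simplifying assumptions: a two-layer linear network with made-up weights, gradients written by hand, and a two-step schedule (the paper uses many more iterations). Instead of a loss, SynFlow scores weights with the output of the network whose weights are replaced by their absolute values, evaluated on an all-ones input, $R = \mathbf{1}^\top \big(\prod_l |W^{l}|\big) \mathbf{1}$, so no data are needed. Scores are recomputed after every pruning round, and the number of remaining weights shrinks exponentially. All helper names are hypothetical.

```python
def synflow_scores(W, v, alive):
    """SynFlow scores |w| * dR/d|w| for R = sum_j |v[j]| * sum_i |W[j][i]|,
    the output of the absolute-valued, masked network on an all-ones input."""
    aW = [[abs(W[j][i]) if ("W", j, i) in alive else 0.0 for i in range(len(W[j]))]
          for j in range(len(W))]
    av = [abs(v[j]) if ("v", j) in alive else 0.0 for j in range(len(v))]
    scores = {}
    for j in range(len(W)):
        for i in range(len(W[j])):
            scores[("W", j, i)] = aW[j][i] * av[j]   # dR/d|W[j][i]| = |v[j]|
        scores[("v", j)] = av[j] * sum(aW[j])        # dR/d|v[j]| = sum_i |W[j][i]|
    return scores

def synflow_prune(W, v, keep_frac, iterations):
    alive = {("W", j, i) for j in range(len(W)) for i in range(len(W[j]))}
    alive |= {("v", j) for j in range(len(v))}
    total = len(alive)
    for k in range(1, iterations + 1):
        n_keep = round(total * keep_frac ** (k / iterations))  # exponential schedule
        scores = synflow_scores(W, v, alive)                   # recomputed every round
        ranked = sorted(alive, key=lambda idx: scores[idx], reverse=True)
        alive = set(ranked[:n_keep])                           # global ranking
    return alive

def magnitude_prune(W, v, keep_frac):
    """One-shot global magnitude pruning, for comparison."""
    mags = {("W", j, i): abs(W[j][i]) for j in range(len(W)) for i in range(len(W[j]))}
    mags.update({("v", j): abs(v[j]) for j in range(len(v))})
    n_keep = round(len(mags) * keep_frac)
    return set(sorted(mags, key=mags.get, reverse=True)[:n_keep])

W = [[1.0, 0.9, 0.8],
     [0.7, 0.6, 0.5]]   # first layer: large weights
v = [0.2, 0.15]         # output layer: small weights

print(sorted(magnitude_prune(W, v, keep_frac=0.25)))
print(sorted(synflow_prune(W, v, keep_frac=0.25, iterations=2)))
```

On this contrived example, one-shot magnitude pruning keeps the two largest first-layer weights and removes the entire output layer, a layer collapse that makes the output identically zero. SynFlow keeps a connected path from input to output.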

However, the work is far from done. As Jonathan Frankle and co-authors point out in a recent paper, there is no universal state-of-the-art solution: each method shines in specific scenarios but is outperformed in others. The methods that prune before training do outperform the random pruning baseline, but they still fall short of some post-training algorithms, especially magnitude-based pruning.

Implementations

Pruning is available both in TensorFlow and PyTorch.
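PyTorch, for instance, provides pruning utilities in `torch.nn.utils.prune`. The sketch below assumes a reasonably recent PyTorch release and a single made-up `Linear` layer, and zeroes out the 30% of weights with the smallest absolute value:

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = torch.nn.Linear(10, 10)   # 100 weights

# Zero out the 30% of weights with the smallest absolute value (L1 criterion).
prune.l1_unstructured(layer, name="weight", amount=0.3)

# PyTorch keeps the original tensor as weight_orig and a 0/1 buffer as weight_mask;
# layer.weight is their elementwise product.
sparsity = (layer.weight == 0).float().mean().item()
print(f"sparsity: {sparsity:.2f}")

# Make the pruning permanent: drop the mask and keep the zeroed tensor.
prune.remove(layer, "weight")
```

The zeros still live in a dense tensor, so this alone does not make the layer smaller or faster; that requires sparse storage or kernels that exploit the zeros.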

Quantization

A neural network is essentially just a bunch of linear algebra and some other operations. By default, most systems use float32 types to represent the variables and weights.

However, computations in other formats, such as int8, can be faster than in float32 and have a smaller memory footprint: an int8 value takes 1 byte instead of 4. (Of course, the actual speedup depends on the hardware, but we are not trying to be extra specific here.)

Neural network quantization is the suite of methods aiming to take advantage of this. For instance, if we would like to go from float32 to int8 as mentioned, and our values are in the range $[-a, a]$ for some real number $a$, we could map them onto the integers from -127 to 127 (int8 covers -128 to 127, so scaling by 128 would overflow at $x = a$) with the transformation

$$x \mapsto \operatorname{round}\left(127 \frac{x}{a}\right)$$

to convert the weights and proceed with the computations in the new form.
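A minimal sketch of symmetric int8 quantization and its inverse, assuming a single scale per tensor with a = max |x| and the rounding map round(127 x / a), so that x = a still fits in int8. Python integers stand in for real int8 storage here, and the weights are made up:

```python
def quantize(xs, a):
    """Map floats in [-a, a] to integers in [-127, 127] (symmetric int8)."""
    return [max(-127, min(127, round(127 * x / a))) for x in xs]

def dequantize(qs, a):
    """Map int8 codes back to approximate float values."""
    return [q * a / 127 for q in qs]

weights = [0.42, -1.3, 0.007, 0.95, -0.61]
a = max(abs(w) for w in weights)       # scale from the largest magnitude
q = quantize(weights, a)
restored = dequantize(q, a)
max_err = max(abs(w - r) for w, r in zip(weights, restored))

print(q)                # [41, -127, 1, 93, -60]
print(max_err, a / 254)
```

Dequantizing shows the information loss: each weight comes back with an error of at most a/254, half a quantization step.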

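Of course, things are not that simple. The product of two int8 numbers no longer fits into int8 and needs 16 bits, and summing many such products in a matrix multiplication needs an even wider (typically 32-bit) accumulator. During quantization, we must take care to avoid the errors this can cause.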
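The overflow issue shows up in the core operation of a quantized layer, the dot product. A minimal sketch, assuming the same symmetric per-tensor scheme (scale a/127) and made-up vectors; `int_dot`, `quantize_tensor`, and `quantized_dot` are hypothetical helpers:

```python
def int_dot(qa, qb):
    """Integer dot product with a wide accumulator.

    Python ints never overflow; on real hardware each int8 x int8 product
    (at most 127 * 127 = 16129 in magnitude) needs 16 bits, and the running
    sum needs more, which is why int8 kernels typically accumulate in int32.
    """
    acc = 0
    for x, y in zip(qa, qb):
        acc += x * y
    return acc

def quantize_tensor(xs):
    """Symmetric per-tensor quantization to [-127, 127]; returns codes and scale."""
    a = max(abs(x) for x in xs)
    return [round(127 * x / a) for x in xs], a / 127

def quantized_dot(xs, ys):
    qx, sx = quantize_tensor(xs)
    qy, sy = quantize_tensor(ys)
    # Integer arithmetic first, then a single rescale back to float.
    return int_dot(qx, qy) * sx * sy

xs = [0.5, -0.25, 1.0, 0.75]
ys = [0.2, 0.4, -0.1, 0.3]
print(quantized_dot(xs, ys), sum(x * y for x, y in zip(xs, ys)))

# Just 64 products of 127 * 127 already exceed the int16 maximum of 32767.
print(int_dot([127] * 64, [127] * 64))
```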

As with all compression methods, this comes with a loss of information and possibly predictive performance. The problem is the same as before: to find an optimal trade-off.

Quantization has two primary flavors: post-training quantization and quantization-aware training. The former is more straightforward but can result in a more significant accuracy loss than the latter.

Quantization methods and their performance in TensorFlow Lite. Source: TensorFlow Lite documentation (https://www.tensorflow.org/lite/performance/model_optimization)

As you can see in the table above, this can cut the inference time in half in some instances. However, rounding from float32 to int8 perturbs every weight slightly, so it can lead to suboptimal results when the loss landscape is rugged and small weight changes have a large effect.

Quantization-aware training goes one step further: it simulates quantization during training, so the network can learn to compensate for the lost precision.

Implementations

Similarly to weight pruning, quantization is also available both in TensorFlow and PyTorch.

At the time of writing, the feature is experimental in PyTorch and subject to change, so you should expect breaking changes in upcoming versions.

So far, the methods we have seen share the same principle: train the network and discard some information to compress it. As we will see, the third one, knowledge distillation, differs from these significantly.

Knowledge distillation

Although quantization and pruning can be effective, they are destructive in the end. An alternative approach was developed by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean in their paper Distilling the Knowledge in a Neural Network.

Their idea is simple: train a big model (teacher) to achieve top performance and use its predictions to train a smaller one (student). Their work showed that large ensemble models could be compressed into simpler architectures that are more suitable for production.
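In Hinton et al.'s formulation, the teacher's logits are turned into soft targets with a softmax at a raised temperature T, and the student is trained on a weighted combination of two terms: the cross-entropy with these soft targets (computed at the same temperature) and the usual cross-entropy with the true labels. The soft term is multiplied by T squared. A minimal sketch of that loss for a single example, where T = 4, alpha = 0.5, and the function names are arbitrary choices for illustration:

```python
import math

def softmax(logits, T=1.0):
    """Softmax with temperature T; a larger T gives a softer distribution."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

def cross_entropy(target_probs, predicted_probs):
    return -sum(p * math.log(q) for p, q in zip(target_probs, predicted_probs) if p > 0)

def distillation_loss(student_logits, teacher_logits, label, T=4.0, alpha=0.5):
    """Weighted sum of a soft-target term and the usual hard-label term.

    The soft term is scaled by T**2 so that its gradients keep a comparable
    magnitude when T changes, as suggested by Hinton et al.
    """
    soft_targets = softmax(teacher_logits, T)
    soft_loss = cross_entropy(soft_targets, softmax(student_logits, T))
    hard_targets = [1.0 if k == label else 0.0 for k in range(len(student_logits))]
    hard_loss = cross_entropy(hard_targets, softmax(student_logits, 1.0))
    return alpha * T ** 2 * soft_loss + (1 - alpha) * hard_loss

teacher_logits = [2.0, 1.0, 0.1]
print(softmax(teacher_logits, 1.0), softmax(teacher_logits, 4.0))
print(distillation_loss([2.0, 1.0, 0.1], teacher_logits, label=0))
print(distillation_loss([0.1, 1.0, 2.0], teacher_logits, label=0))
```

The softened distribution keeps the teacher's relative preferences among the wrong classes, which is extra information that hard labels alone do not carry.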

Knowledge distillation improves the inference time of the distilled models, not the training time. This is an essential distinction from the other two methods, since training often has a high cost. (If we think back to the GPT-3 example, it was millions of dollars.)

You might ask: why not just use a compact architecture from the start? The secret sauce is to teach the student model to generalize like the teacher by using the teacher's predictions. Here, the student sees the teacher's training data as well as new data, and it is fitted to approximate the teacher's output.
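Fitting the student to the teacher's outputs on new, unlabeled data can be illustrated with a toy example. Everything here is a simplifying assumption: `teacher` is a hypothetical stand-in for a big trained model, the student is a logistic regression, and the transfer set consists of random unlabeled points.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def teacher(x):
    """Hypothetical stand-in for a large trained model: returns P(class 1 | x)."""
    return sigmoid(3.0 * x[0] - 2.0 * x[1] + 0.5)

def fit_student(transfer_set, steps=2000, lr=1.0):
    """Fit a logistic-regression student to the teacher's probabilities.

    Minimizes the cross-entropy between teacher and student outputs; its
    gradient with respect to the student's logit is p_student - p_teacher.
    """
    targets = [teacher(x) for x in transfer_set]   # soft labels, no ground truth
    w0 = w1 = b = 0.0
    n = len(transfer_set)
    for _ in range(steps):
        g0 = g1 = gb = 0.0
        for (x0, x1), p_t in zip(transfer_set, targets):
            err = sigmoid(w0 * x0 + w1 * x1 + b) - p_t
            g0 += err * x0 / n
            g1 += err * x1 / n
            gb += err / n
        w0, w1, b = w0 - lr * g0, w1 - lr * g1, b - lr * gb
    return w0, w1, b

rng = random.Random(0)
transfer_set = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(100)]
print(fit_student(transfer_set))
```

Because this toy teacher happens to be a logistic function itself, the student can match it exactly and recovers parameters close to (3, -2, 0.5). A real student is smaller than its teacher and can only approximate it.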

Knowledge distillation results on a speech recognition problem from the paper Distilling the Knowledge in a Neural Network by Geoffrey Hinton, Oriol Vinyals, and Jeff Dean (https://arxiv.org/pdf/1503.02531.pdf)

The smaller a model is, the more training data it needs to generalize well. Thus, it might take a complex architecture such as an ensemble model to achieve state-of-the-art performance on challenging tasks. Still, the big model's knowledge can be used to push the student model's performance beyond the baseline.

One of the first use cases for knowledge distillation was compressing ensembles and making them suitable for production. Ensembles were notorious in Kaggle competitions: many winning solutions combined several models, offering outstanding results but being impractical to use in production.

Since then, it has been applied successfully to other architectures, most notably BERT, the famous transformer model for NLP (see, for example, DistilBERT).

Besides the baseline distillation approach by Hinton et al., several other approaches try to push the state of the art. If you want an overview of them, I recommend this survey paper.

Implementations

Since knowledge distillation does not require manipulating weights like pruning or quantization, it can be performed in any framework of your choice. Here are some examples to get you started!

Conclusion

As neural networks are getting larger and larger, compression of the models is becoming even more critical. As the complexity of the problems and architectures increases, so does the computational cost and the environmental impact.

This trend only seems to accelerate: GPT-3 contains 175 billion parameters, about ten times more than previous giant models. Thus, compressing these networks is a fundamental problem, which will become even more critical in the future.

Are you ready to tackle this challenge?

Understanding math is a superpower in machine learning.

I am writing a book about it to help you go from high school mathematics to neural networks.
Join me on this journey and let's do this together!