Why does gradient descent work?
Young man, in mathematics you don’t understand things.
You just get used to them. — John von Neumann
In machine learning, we use gradient descent so much that we get used to it. We hardly ever question why it works.
What's usually told is the mountain-climbing analogy: to find the peak (or the bottom) of a bumpy terrain, one has to look at the direction of the steepest ascent (or descent) and take a step in that direction. This direction is described by the gradient, and the iterative process of finding local extrema by following the gradient is called gradient ascent/descent. (Ascent for finding peaks, descent for finding valleys.)
However, this is not a mathematically precise explanation. Several questions are left unanswered, and based on our mountain-climbing intuition, it's not even clear if the algorithm works.
Without a precise understanding of gradient descent, we are practically flying blind. In this post, our goal is to develop a proper mathematical framework that will help you understand what's going on behind the scenes, allowing you to reason effectively about gradient descent and possibly improve its performance in your projects. Our journey will lead us through
- differentiation as the rate of change,
- the principles of optimization with first and second derivatives,
- the basics of differential equations,
- and how gradient descent is equivalent to physical systems flowing towards their equilibrium state.
Buckle up! Deep dive into the beautiful world of dynamical systems incoming.
Derivatives and their meaning
First things first: if you are not familiar with gradient descent, I wrote a detailed introductory article about it, which I recommend you check before moving on.
If you are comfortable with the gradient descent algorithm, we can start to peek behind the curtain. Let's talk about derivatives first! By definition, we say that a function $f: \mathbb{R} \to \mathbb{R}$ is differentiable at $x_0$ if the limit
$$ f'(x_0) = \lim_{x \to x_0} \frac{f(x) - f(x_0)}{x - x_0} $$
exists. $f'(x_0)$ is called the derivative. Although this definition seems random, there is a deep meaning behind it: the derivative can be thought of as the rate of change.
Derivative as speed
Let's jump back in time a few hundred years! At their inception, derivatives were created to describe the speed of moving objects. Suppose that the position of our object at time $t$ is given by the function $x(t)$, and for simplicity, assume that it is moving along a straight line, something like this below.
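Our goal is to calculate the object's speed at a given time. In high school, we learned that
$$ \text{average speed} = \frac{\text{distance}}{\text{time}}. $$
To put this into a quantitative form, if $t_1 < t_2$ are two arbitrary points in time, then
$$ \text{average speed between } t_1 \text{ and } t_2 = \frac{x(t_2) - x(t_1)}{t_2 - t_1}. $$
Expressions like $\frac{x(t_2) - x(t_1)}{t_2 - t_1}$ are called differential quotients. Note that if the object moves backwards, the average speed is negative.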
The average speed has a simple geometric interpretation. If you replace the object's motion with a constant velocity motion moving at the average speed, you'll end up at the same place. In graphical terms, this is equivalent to connecting $(t_1, x(t_1))$ and $(t_2, x(t_2))$ with a single line. The average speed is just the slope of this line.
Given this, we can calculate the exact speed at a single time point $t$. The idea is simple: the average speed in the small time-interval between $t$ and $t + \Delta t$ should get closer and closer to the speed at $t$ as $\Delta t$ gets smaller. ($\Delta t$ can be negative as well.) In formulas,
$$ x'(t) = \lim_{\Delta t \to 0} \frac{x(t + \Delta t) - x(t)}{\Delta t}. $$
Geometrically, you can visualize the derivative as the slope of the tangent line at $t$, and the differential quotients as slopes of the lines connecting the points of the function's graph at $t$ and $t + \Delta t$.
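To make this limiting process tangible, here is a minimal sketch that evaluates differential quotients for shrinking time steps. The trajectory `position` (that is, $x(t) = t^2$) and the helper name `differential_quotient` are my own choices for illustration, not something fixed by the text above.

```python
def differential_quotient(x, t, dt):
    """Average speed of the trajectory x between the times t and t + dt."""
    return (x(t + dt) - x(t)) / dt


def position(t):
    # Hypothetical trajectory, chosen only for illustration: x(t) = t^2,
    # whose exact speed is x'(t) = 2t.
    return t ** 2


t = 1.5
exact_speed = 2 * t
steps = [1.0, 0.1, 0.01, 0.001, -0.001]  # the last step is negative on purpose
quotients = [differential_quotient(position, t, dt) for dt in steps]

for dt, q in zip(steps, quotients):
    print(f"dt = {dt:>6}: quotient = {q:.6f}, error = {abs(q - exact_speed):.6f}")
```

For this particular trajectory, the error of the quotient shrinks together with the step, which is exactly what the limit in the definition expresses.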
Local extrema and the derivative
Derivatives can tell us a lot more than speed. We've just seen that the derivative equals the slope of the tangent line. With this in mind, take a look at the figure below!
Notice that the tangent is perfectly horizontal at the local minima and maxima. In mathematical terms, this is equivalent to $f'(x) = 0$. Can we use this property to find minima and maxima?
Yes, but it's not that simple. Take a look at $f(x) = x^3$. Its derivative is zero at $x = 0$, but that is neither a local minimum nor a local maximum. Without going into the fine details, the best we can do is to compute the second derivative and hope that it'll give some clarity. To be precise, the following theorem holds.
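Theorem. Let $f: \mathbb{R} \to \mathbb{R}$ be an arbitrary function that is twice differentiable at some $x_0 \in \mathbb{R}$.
(a) If $f'(x_0) = 0$ and $f''(x_0) > 0$, then $x_0$ is a local minimum.
(b) If $f'(x_0) = 0$ and $f''(x_0) < 0$, then $x_0$ is a local maximum.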
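The theorem can be tried out numerically. Below is a rough sketch that estimates the first two derivatives with central differences and applies the test; the function names, step sizes, and tolerance are assumptions of mine, and finite differences are only a stand-in for the exact derivatives.

```python
def first_derivative(f, x, h=1e-5):
    """Central-difference estimate of f'(x)."""
    return (f(x + h) - f(x - h)) / (2 * h)


def second_derivative(f, x, h=1e-4):
    """Central-difference estimate of f''(x)."""
    return (f(x + h) - 2 * f(x) + f(x - h)) / (h * h)


def classify_critical_point(f, x0, tol=1e-6):
    """Apply the second derivative test at x0 (numerically, up to tol)."""
    if abs(first_derivative(f, x0)) > tol:
        return "not a critical point"
    curvature = second_derivative(f, x0)
    if curvature > tol:
        return "local minimum"
    if curvature < -tol:
        return "local maximum"
    return "inconclusive"  # the theorem says nothing when f''(x0) = 0


print(classify_critical_point(lambda x: x ** 2, 0.0))
print(classify_critical_point(lambda x: -(x ** 2), 0.0))
print(classify_critical_point(lambda x: x ** 3, 0.0))
```

Note the third case: for the cubic, both derivatives vanish at zero, so the test stays silent, which matches the warning above.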
Why is this important to us? Well, because machine learning is (mostly) just a colossal optimization problem. Construct a parametric model, and find a set of parameters that maximize its performance. As you know, this is done via the famous gradient descent algorithm, which seemingly has nothing to do with our simple result regarding the second derivatives. However, quite the opposite is true: deep down at its core, this stems from the relation of local extrema and the first two derivatives.
In the following, we are going to learn why.
Differential equations 101
What is a differential equation?
Equations play an essential role in mathematics. This is common wisdom, but there is a profound truth behind it. Quite often, equations arise from modeling systems such as interactions in a biochemical network, economic processes, and thousands more. For instance, modeling the metabolic processes in organisms yields linear equations of the form
$$ Ax = b, $$
where the vectors $x$ and $b$ represent the concentration of molecules (where $x$ is the unknown), and the matrix $A$ represents the interactions between them. Linear equations are easy to solve, and we understand quite a lot about them.
However, the equations we have seen so far are unfit to model dynamical systems, as they lack a time component. To describe, for example, the trajectory of a space station orbiting around Earth, we have to describe our models in terms of functions and their derivatives.
For instance, the trajectory of a swinging pendulum can be described by the equation
$$ x''(t) + \frac{g}{L} \sin x(t) = 0, $$
where
- $x(t)$ describes the angle of the pendulum from the vertical,
- $L$ is the length of the (massless) rod that our object of mass $m$ hangs on,
- and $g$ is the gravitational acceleration constant, $g \approx 9.81 \ \mathrm{m/s^2}$.
According to the original interpretation of differentiation, if $x(t)$ describes the movement of the pendulum at time $t$, then $x'(t)$ and $x''(t)$ describe its velocity and acceleration, where the differentiation is taken with respect to the time $t$.
(In fact, the differential equation is a direct consequence of Newton's second law of motion.)
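To get a feel for what such an equation encodes, here is a small simulation sketch. It assumes the pendulum equation in the form "angular acceleration equals $-(g/L)$ times the sine of the angle", and it advances the angle and the angular velocity in tiny time steps. The rod length, the time step, and the initial angle are hypothetical values picked for illustration, and the velocity-first update is just one simple choice among many numerical schemes.

```python
import math


def simulate_pendulum(angle0, velocity0, length=1.0, g=9.81, dt=1e-3, n_steps=5000):
    """Step the pendulum forward in time and return the list of angles (radians)."""
    angle, velocity = angle0, velocity0
    angles = [angle]
    for _ in range(n_steps):
        acceleration = -(g / length) * math.sin(angle)
        velocity += dt * acceleration  # update the velocity first...
        angle += dt * velocity         # ...then move using the new velocity
        angles.append(angle)
    return angles


# Release the pendulum from 0.5 radians with zero initial velocity.
angles = simulate_pendulum(angle0=0.5, velocity0=0.0)
print(f"largest angle:  {max(angles):.3f}")
print(f"smallest angle: {min(angles):.3f}")
```

With no friction in the model, the simulated pendulum keeps swinging back and forth between roughly the same two extreme angles.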
Equations involving functions and their derivatives, such as the equation of the swinging pendulum above, are called ordinary differential equations, or ODEs for short. Without exaggeration, their study has been the primary motivating force of mathematics since the 17th century. Trust me when I say this: differential equations are among the most beautiful objects in mathematics. As we are about to see, the gradient descent algorithm is, in fact, an approximate solution of a differential equation.
The first part of this post will serve as a quickstart to differential equations. I will mostly follow the fantastic Nonlinear Dynamics and Chaos book by Steven Strogatz. If you ever desire to dig deep into dynamical systems, I wholeheartedly recommend this book to you. (This is one of my favorite math books ever. It reads like a novel. The quality and clarity of its exposition serve as a continuous inspiration for my writing.)
The (slightly more) general form of ODEs
Let's dive straight into the deep waters and start with an example to get a grip on differential equations. Quite possibly, the simplest example is the equation
$$ x'(t) = x(t), $$
where the differentiation is taken with respect to the time variable $t$. If, for example, $x(t)$ is the size of a bacterial colony, the equation describes its population dynamics if the growth is unlimited. Think about $x'(t)$ as the rate at which the population grows: if there are no limitations in space and nutrients, every bacterial cell can freely replicate whenever possible. Thus, since every cell can freely divide, the speed of growth matches the colony's size.
In plain English, the solutions of the equation $x'(t) = x(t)$ are functions whose derivatives are themselves. After a bit of thinking, we can come up with a family of solutions: $x(t) = c e^t$, where $c \in \mathbb{R}$ is an arbitrary constant. (Recall that $e^t$ is an elementary function, and its derivative is itself.)
Some of the solutions are plotted below.
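If you prefer checking over trusting, the claim that every function of the form "constant times $e^t$" has a derivative equal to itself can be tested numerically. The sketch below is only a sanity check under my own choices: the constants, the sample times, and the central-difference step are hypothetical.

```python
import math


def solution(t, c):
    """Candidate solution: c * e^t. At t = 0 it takes the value c."""
    return c * math.exp(t)


def residual(t, c, h=1e-6):
    """Derivative minus value, with the derivative estimated by a central difference."""
    derivative = (solution(t + h, c) - solution(t - h, c)) / (2 * h)
    return derivative - solution(t, c)


constants = [-1.0, 0.5, 2.0]  # hypothetical constants, chosen for illustration
times = [0.0, 0.5, 1.0, 2.0]
worst = max(abs(residual(t, c)) for c in constants for t in times)
print(f"largest residual: {worst:.2e}")
```

The residual is tiny for every constant, so each member of the family behaves like a solution; the constant only changes where the curve starts.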
There are two key takeaways here: differential equations describe dynamical processes that change in time, and they can have multiple solutions. Each solution is determined by two factors: the equation itself, $x'(t) = x(t)$, and an initial condition $x(0) = x_0$. If we specify $x(0) = x_0$, then the value of $c$ is given by
$$ x(0) = c e^0 = c = x_0. $$
Thus, ODEs have a bundle of solutions, each one determined by the initial condition. So, it's time to discuss differential equations in more general terms!
Definition. (Ordinary differential equations in one dimension.) Let $f: \mathbb{R} \to \mathbb{R}$ be a differentiable function. The equation
$$ x'(t) = f(x(t)) $$
is called a first-order homogeneous ordinary differential equation.
When it is clear, the dependence on $t$ is often omitted, so we only write $x' = f(x)$. (Some resources denote the time derivative by $\dot{x}$, a notation that originates from Newton. We will not use this, though it is good to know.)
The term "first-order homogeneous ordinary differential equation" doesn't exactly roll off the tongue, and it is overloaded with heavy terminology. So, let's unpack what is going on here.
The differential equation part is clear: it is a functional equation that involves derivatives. Since the time $t$ is the only variable, the differential equation is ordinary. (As opposed to differential equations involving multivariable functions and partial derivatives, but more on those later.) As only the first derivative is present, the equation is first-order. Second-order would involve second derivatives, and so on. Finally, since the right-hand side doesn't explicitly depend on the time variable $t$, the equation is homogeneous in time. (Such equations are more commonly called autonomous.) Homogeneity means that the rules governing our dynamical system don't change over time.
Don't let the $f(x(t))$ part scare you! For instance, in our example $x'(t) = x(t)$, the role of $f$ is played by the identity function $f(x) = x$. In general, $f$ establishes a relation between the quantity $x(t)$ (which can be position, density, etc.) and its derivative, that is, its rate of change.
As we have seen, we think in terms of differential equations and initial conditions that pinpoint solutions among a bundle of functions. Let's put this into a proper mathematical definition!
Definition. (Initial value problems.) Let $x' = f(x)$ be a first-order homogeneous ordinary differential equation and let $x_0 \in \mathbb{R}$ be an arbitrary value. The system
$$ \begin{cases} x' = f(x) \\ x(t_0) = x_0 \end{cases} $$
is called an initial value problem. If a function $x(t)$ satisfies both conditions, it is said to be a solution to the initial value problem.
Most often, we select $t_0$ to be $0$. After all, we have the freedom to select the origin of the time as we want.
Unfortunately, things are not as simple as they seem. In general, differential equations and initial value problems are tough to solve. Except for a few simple ones, we cannot find exact solutions. (And when I say we, I include every person on the planet.) In these cases, there are two things that we can do: either we construct approximate solutions via numeric methods or turn to qualitative methods that study the behavior of the solutions without actually finding them.
We'll talk about both, but let's turn to the qualitative methods first. As we'll see, looking from a geometric perspective gives us a deep insight into how differential equations work.
A geometric interpretation of differential equations
When finding analytic solutions is not feasible, we look for a qualitative understanding of the solutions, focusing on the local and long-term behavior instead of formulas.
Imagine that given a differential equation
$$ x' = f(x), $$
you are interested in a particular solution that assumes the value $x_0$ at time $t_0$. For instance, you could be studying the dynamics of a bacterial colony and want to provide a predictive model to fit your latest measurement $x(t_0) = x_0$. In the short term, where will your solutions go?
We can immediately notice that if $x(t_0) = x_0$ and $f(x_0) = 0$, then the constant function $x(t) = x_0$ is a solution! These are called equilibrium solutions, and they are extremely important. So, let's make a formal definition!
Definition. (Equilibrium solutions.) Let
$$ x' = f(x) $$
be a first-order homogeneous ODE, and let $x^* \in \mathbb{R}$ be an arbitrary point. If $f(x^*) = 0$, then $x^*$ is called an equilibrium point of the equation $x' = f(x)$.
For equilibrium points, the constant function $x(t) = x^*$ is a solution of $x' = f(x)$. This is called an equilibrium solution.
Think about our recurring example, the simplest ODE $x'(t) = x(t)$. As mentioned, we can interpret this equation as a model of unrestricted population growth under ideal conditions. In that case, $f(x) = x$, and this is zero only for $x = 0$. Therefore, the constant function $x(t) = 0$ is a solution. This makes perfect sense: if a population has zero individuals, no change is going to happen in its size. In other words, the system is in equilibrium.
Think of a pendulum that has stopped moving and reached its resting point at the bottom. However, pendulums have two equilibria: one at the top and one at the bottom. (Let's suppose that the mass is held by a massless rod. Otherwise, it would collapse.) At the bottom, you can push the hanging mass all you want, and it'll return to rest. However, at the top, any small push would disrupt the equilibrium state, to which it would never return.
To shed light on this phenomenon, let's look at another example: the famous logistic equation
$$ x'(t) = x(t) \big( 1 - x(t) \big). $$
From a population dynamics perspective, if our favorite equation $x'(t) = x(t)$ describes the unrestricted growth of a bacterial colony, the logistic equation models the population growth under a resource constraint. If we assume that $1$ is the total capacity of our population, the growth becomes more difficult as the size approaches this limit. Thus, the population's rate of change can be modelled as $x(t)(1 - x(t))$, where the term $1 - x(t)$ slows down the process as the colony nears the carrying capacity.
We can write the logistic equation in the general form $x' = f(x)$ by casting the role $f(x) = x(1 - x)$. Do you recall the relation between derivatives and monotonicity? Translated to the differential equation $x' = f(x)$, this reveals the flow of our solutions! To be specific,
- $x(t)$ is increasing where $f(x(t)) > 0$, that is, for $0 < x(t) < 1$,
- $x(t)$ is decreasing where $f(x(t)) < 0$, that is, for $x(t) < 0$ or $x(t) > 1$,
- and $x(t)$ is constant where $f(x(t)) = 0$, that is, for $x(t) = 0$ or $x(t) = 1$.
We can visualize this in the so-called phase portrait.
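Thus, the monotonicity describes long-term behavior: a solution starting from $x(0) > 0$ tends to $1$, the solution starting from $x(0) = 0$ stays at $0$, and a solution starting from $x(0) < 0$ decreases without bound.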
With a little bit of calculation (whose details are not essential for us), we can write the solutions as
$$ x(t) = \frac{1}{1 + c e^{-t}}, $$
where $c \in \mathbb{R}$ is an arbitrary constant. For $c = 1$, this is the famous sigmoid function. You can check by hand that these are indeed solutions. We can even plot them, as shown below.
As we can see, the monotonicity of the solutions is as we predicted.
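Checking by hand is good exercise, but a numerical check is quicker. The sketch below assumes the closed form $1 / (1 + c e^{-t})$ for the solutions of the logistic equation $x' = x(1 - x)$, and compares a finite-difference derivative with the right-hand side. The constants and sample times are hypothetical; $c = -0.5$ is picked so that the solution starts above the capacity, at $x(0) = 2$.

```python
import math


def f(x):
    """Right-hand side of the logistic equation x' = x * (1 - x)."""
    return x * (1.0 - x)


def logistic_solution(t, c):
    """Assumed closed form of the solutions: 1 / (1 + c * e^(-t))."""
    return 1.0 / (1.0 + c * math.exp(-t))


def max_residual(c, times, h=1e-6):
    """Largest gap between the estimated derivative and f along the solution."""
    worst = 0.0
    for t in times:
        derivative = (logistic_solution(t + h, c) - logistic_solution(t - h, c)) / (2 * h)
        worst = max(worst, abs(derivative - f(logistic_solution(t, c))))
    return worst


times = [0.5 * k for k in range(11)]  # t = 0, 0.5, ..., 5
for c in (4.0, 1.0, -0.5):
    start = logistic_solution(times[0], c)
    end = logistic_solution(times[-1], c)
    print(f"c = {c:>4}: x(0) = {start:.3f}, x(5) = {end:.3f}, "
          f"residual = {max_residual(c, times):.1e}")
```

Solutions starting between zero and one climb towards one, while the solution starting at two descends towards it, in line with the phase portrait.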
We can characterize the equilibria based on the long-term behavior of nearby solutions. (In the case of our logistic equation, the equilibria are $x^* = 0$ and $x^* = 1$.) This can be connected to the local behavior of $f$: if $f$ decreases around the equilibrium $x^*$, it attracts the nearby solutions. On the other hand, if $f$ increases around $x^*$, the nearby solutions are repelled.
This gives rise to the concept of stable and unstable equilibria.
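Definition. (Stable and unstable equilibria.) Let $x' = f(x)$ be a first-order homogeneous ordinary differential equation, and suppose that $f$ is differentiable. Moreover, let $x^*$ be an equilibrium of the equation.
$x^*$ is called a stable equilibrium if there is a neighborhood $(x^* - \varepsilon, x^* + \varepsilon)$ around $x^*$ such that for all $x_0 \in (x^* - \varepsilon, x^* + \varepsilon)$, the solution of the initial value problem
$$ \begin{cases} x' = f(x) \\ x(0) = x_0 \end{cases} $$
converges towards $x^*$. (That is, $\lim_{t \to \infty} x(t) = x^*$ holds.) If $x^*$ is not stable, it is called unstable.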
In the case of the logistic ODE $x' = x(1 - x)$, $x^* = 1$ is a stable and $x^* = 0$ is an unstable equilibrium. This makes sense given its population dynamics interpretation: the equilibrium $x^* = 1$ means that the population is at maximum capacity. If the size is slightly above the capacity $1$, some specimens die due to starvation; if it is slightly below, the colony grows until it reaches its constraints. On the other hand, no matter how small the population is, it won't ever go extinct in this ideal model.
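The rule of thumb "decreasing right-hand side attracts, increasing right-hand side repels" is easy to turn into a few lines of code. This is a sketch under my own assumptions: the right-hand side is the logistic one, $x(1 - x)$, the slope is estimated with a central difference, and the function names and tolerances are made up for the example.

```python
def f(x):
    """Right-hand side of the logistic equation x' = x * (1 - x)."""
    return x * (1.0 - x)


def slope(g, x, h=1e-6):
    """Central-difference estimate of g'(x)."""
    return (g(x + h) - g(x - h)) / (2 * h)


def classify_equilibrium(g, x_star, tol=1e-6):
    """Label an equilibrium of x' = g(x) by the sign of the slope of g there."""
    if abs(g(x_star)) > tol:
        raise ValueError(f"{x_star} is not an equilibrium")
    s = slope(g, x_star)
    if s < -tol:
        return "stable"      # g decreases: nearby solutions are attracted
    if s > tol:
        return "unstable"    # g increases: nearby solutions are repelled
    return "inconclusive"    # zero slope: this simple test cannot decide


for x_star in (0.0, 1.0):
    print(f"x* = {x_star}: {classify_equilibrium(f, x_star)}")
```

The zero-slope branch is there on purpose: when the slope vanishes, the sign test gives no information and the equilibrium has to be examined by other means.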
A continuous version of gradient ascent
Now, let's talk about maximizing a function $F: \mathbb{R} \to \mathbb{R}$. Suppose that $F$ is twice differentiable, and we denote its derivative by $f = F'$. Luckily, the local maxima of $F$ can be found with the help of its second derivative by looking for $x^*$ where $F'(x^*) = 0$ and $F''(x^*) < 0$.
Does this look familiar? If $f(x^*) = 0$ indeed holds, then $x(t) = x^*$ is an equilibrium solution of $x' = f(x)$; and since $f'(x^*) = F''(x^*) < 0$, it attracts the nearby solutions as well. This means that if $x_0$ is drawn from the basin of attraction and $x(t)$ is the solution of the initial value problem
$$ \begin{cases} x' = f(x) \\ x(0) = x_0 \end{cases} $$
then $\lim_{t \to \infty} x(t) = x^*$. In other words, the solution converges towards $x^*$, a local maximum of $F$! This is gradient ascent in a continuous version.
We are happy, but there is an issue. We've talked about how hard solving differential equations is. For a general $F$, we have no prospects of actually finding the solutions. Fortunately, we can approximate them.
Gradient ascent as a discretized differential equation
When doing differentiation in practice, derivatives are often approximated numerically by the forward difference
$$ x'(t) \approx \frac{x(t + h) - x(t)}{h}. $$
If $x(t)$ is indeed the solution for the corresponding initial value problem, we are in luck! Using forward differences, we can take a small step from $0$ and approximate $x(h)$ by substituting the forward difference into the differential equation. To be precise, we have
$$ f(x(0)) = x'(0) \approx \frac{x(h) - x(0)}{h}, $$
from which
$$ x(h) \approx x(0) + h f(x(0)) $$
follows. By defining $x_0$ and $x_1$ by
$$ x_0 := x(0), \quad x_1 := x_0 + h f(x_0), $$
we have $x_1 \approx x(h)$. If this looks like the first step of the gradient ascent to you, you are on the right track. Using the forward difference once again, this time from the point $x(h)$, we obtain
$$ x(2h) \approx x(h) + h f(x(h)) \approx x_1 + h f(x_1), $$
thus by defining $x_2 := x_1 + h f(x_1)$, we have $x_2 \approx x(2h)$. Notice that in $x_2$, two kinds of approximation errors are accumulated: first the forward difference, then the approximation error of the previous step.
This motivates us to define the recursive sequence
$$ x_0 := x(0), \quad x_{n+1} := x_n + h f(x_n), $$
which approximates $x(nh)$ with $x_n$, as this is implied by the very definition. This recursive sequence is the gradient ascent itself, and the small step $h$ is the learning rate! In the context of differential equations, this is called the Euler method.
Without going into the details, if $h$ is small enough and $f$ "behaves properly", the Euler method will converge to the equilibrium solution $x^*$. (Whatever proper behavior might mean.)
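Here is what the Euler method looks like in code: each new point is the previous one plus the step size times the right-hand side evaluated there. This is a minimal sketch under my own assumptions. The function to maximize, $F(x) = -(x - 2)^2$, its derivative `f`, the closed-form `exact` solution, and the step sizes are all hypothetical choices for illustration.

```python
import math


def euler(f, x0, h, n_steps):
    """Repeat the Euler update x + h * f(x); return every iterate, x0 included."""
    xs = [x0]
    for _ in range(n_steps):
        xs.append(xs[-1] + h * f(xs[-1]))
    return xs


def f(x):
    # Hypothetical example: F(x) = -(x - 2)^2 has its maximum at 2,
    # and its derivative is f(x) = -2 * (x - 2).
    return -2.0 * (x - 2.0)


def exact(t, x0=0.0):
    # Exact solution of x' = f(x) with x(0) = x0, valid for this particular f.
    return 2.0 + (x0 - 2.0) * math.exp(-2.0 * t)


h = 0.1
xs = euler(f, x0=0.0, h=h, n_steps=100)
print(f"after 100 steps: {xs[-1]:.8f}")
print(f"gap to the exact solution after 5 steps: {abs(xs[5] - exact(5 * h)):.4f}")

# With a step that is too large, every update overshoots and the iteration diverges.
diverging = euler(f, x0=0.0, h=1.1, n_steps=20)
print(f"step 1.1, after 20 steps: {diverging[-1]:.2f}")
```

The small step lands next to the maximizer even though the individual iterates drift away from the exact trajectory, while the large step blows up. For this example, "small enough" has a concrete meaning: each update multiplies the distance from the maximizer by $1 - 2h$, so the iteration converges only when the step size is below one.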
We only have one more step: to turn everything into gradient descent instead of ascent. This is extremely simple, as gradient descent is just applying gradient ascent to $-F$. Think about it: minimizing a function $F$ is the same as maximizing its negative $-F$. And with that, we are done! The famous gradient descent is a consequence of dynamical systems converging towards their stable equilibria, and this is beautiful.
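The sign flip can be verified in a couple of lines. In the sketch below, the objective $F(x) = (x - 3)^2 + 1$, the starting point, and the learning rate are hypothetical; the point is only that descending on a function and ascending on its negative produce the same iterates.

```python
def gradient_ascent(f, x0, h, n_steps):
    """Repeat the ascent update x + h * f(x), where f is the derivative to follow."""
    x = x0
    for _ in range(n_steps):
        x = x + h * f(x)
    return x


def gradient_descent(grad, x0, learning_rate, n_steps):
    """Repeat the descent update x - learning_rate * grad(x)."""
    x = x0
    for _ in range(n_steps):
        x = x - learning_rate * grad(x)
    return x


def grad_F(x):
    # Hypothetical objective: F(x) = (x - 3)^2 + 1, minimized at 3.
    return 2.0 * (x - 3.0)


def grad_minus_F(x):
    # Derivative of the negated objective, -F.
    return -grad_F(x)


x_descent = gradient_descent(grad_F, x0=0.0, learning_rate=0.1, n_steps=200)
x_ascent = gradient_ascent(grad_minus_F, x0=0.0, h=0.1, n_steps=200)
print(f"descent on F:  {x_descent:.10f}")
print(f"ascent on -F:  {x_ascent:.10f}")
```

Both runs end up at the minimizer of the objective, so everything said about ascent carries over to descent unchanged.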
The gradient ascent in action
To see the gradient ascent (that is, the Euler method) in action, we should go back to our good old example: the logistic equation $x' = x(1 - x)$. So, suppose that we want to find the local maxima of the function
$$ F(x) = \frac{x^2}{2} - \frac{x^3}{3}, $$
which is plotted below.
First, we can use what we learned and find the maxima using the derivative $f(x) = F'(x) = x(1 - x)$, concluding that there is a local maximum at $x^* = 1$. (Don't just take my word, pick up a pencil and work it out!)
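Since $f(1) = 0$ and $f'(1) = -1 < 0$, the point $x^* = 1$ is a stable equilibrium of the logistic equation
$$ x' = x(1 - x). $$
Thus, if the initial value $x_0$ is sufficiently close to $x^* = 1$ and $x(t)$ is the solution of the initial value problem
$$ \begin{cases} x' = x(1 - x) \\ x(0) = x_0 \end{cases} $$
then $\lim_{t \to \infty} x(t) = 1$. (In fact, we can select any initial value from the infinite interval $(0, \infty)$, and the convergence will hold.) Upon discretization via the Euler method, we obtain the recursive sequence
$$ x_0 := x(0), \quad x_{n+1} := x_n + h x_n (1 - x_n). $$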
This process is visualized below.
Solving $x' = x(1 - x)$ via the Euler method. (The initial value $x_0$ was chosen for visualization purposes.)
We can even take the discrete solution provided by the Euler method and plot it on the $x$-$F(x)$ plane, that is, on the graph of $F$.
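The whole experiment fits into a dozen lines. The sketch below runs the Euler iteration for the logistic equation $x' = x(1 - x)$ and reports the value of the function being maximized along the way. I take that function to be $F(x) = x^2/2 - x^3/3$ (any antiderivative of $x(1 - x)$ would do); the starting point, the learning rate, and the number of steps are hypothetical values of my choosing.

```python
def F(x):
    """Function to maximize; its derivative is x * (1 - x)."""
    return x ** 2 / 2 - x ** 3 / 3


def f(x):
    """Derivative of F, which is also the right-hand side of the logistic equation."""
    return x * (1 - x)


h = 0.1    # learning rate, that is, the Euler step size (hypothetical)
x0 = 0.1   # starting point inside the basin of attraction (hypothetical)

xs = [x0]
for _ in range(200):
    xs.append(xs[-1] + h * f(xs[-1]))  # gradient ascent step = Euler step

for n in (0, 10, 50, 200):
    print(f"n = {n:>3}: x = {xs[n]:.6f}, F(x) = {F(xs[n]):.6f}")
```

Every step moves the iterate to the right and closer to one, where $F$ attains its local maximum value of one sixth.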
To sum up what we've seen so far, our entire goal was to understand the very principles of gradient descent, the most important optimization algorithm in machine learning. Its main principle is straightforward: to find a local minimum of a function, first find the direction of decrease, then take a small step in that direction. This seemingly naive algorithm has a foundation that lies deep within differential equations. It turns out that if we look at our functions as rules determining a dynamical system, local extrema correspond to equilibrium states. These dynamical systems are described by differential equations, and the local maxima are stable equilibrium states that attract the nearby solutions. From this viewpoint, the gradient descent algorithm is nothing else than a numerical solution to this equation.
What we've seen so far only covers the single-variable case, and as I have probably said many times, machine learning is done in millions of dimensions. Still, the intuition we built up will be our guide in the study of multivariable functions and high-dimensional spaces. There, the principles are the same, but the objects of study are much more complex. The main challenge in multivariable calculus is to manage the complexity, and this is where our good friends, vectors and matrices, will do much of the heavy lifting. But that's for another day!