Why does gradient descent work?

Tivadar Danka

Understanding math will make you a better engineer.

So, I am writing the best and most comprehensive book about it.

Young man, in mathematics you don’t understand things.
You just get used to them. — John von Neumann

In machine learning, we use gradient descent so much that we get used to it. We hardly ever question why it works.

What's usually told is the mountain-climbing analogy: to find the peak (or the bottom) of a bumpy terrain, one has to look in the direction of the steepest ascent (or descent) and take a step that way. This direction is described by the gradient, and the iterative process of finding local extrema by following the gradient is called gradient ascent/descent. (Ascent for finding peaks, descent for finding valleys.)

However, this is not a mathematically precise explanation. Several questions are left unanswered, and based on our mountain-climbing intuition, it's not even clear if the algorithm works.

Without a precise understanding of gradient descent, we are practically flying blind. In this post, our goal is to develop a proper mathematical framework that will help you understand what's going on behind the scenes, allowing you to reason effectively about gradient descent and possibly improve its performance in your projects. Our journey will lead us through

  • differentiation as the rate of change,
  • the principles of optimization with first and second derivatives,
  • the basics of differential equations,
  • and how gradient descent is equivalent to physical systems flowing towards their equilibrium state.

Buckle up! Deep dive into the beautiful world of dynamical systems incoming.

Derivatives and their meaning

First things first: if you are not familiar with gradient descent, I wrote a detailed introductory article about it, which I recommend you check before moving on.

If you are comfortable with the gradient descent algorithm, we can start to peek behind the curtain. Let's talk about derivatives first! By definition, we say that a function $f$ is differentiable at $x_0$ if the limit

$$f^\prime(x_0) := \lim_{h \to 0} \frac{f(x_0 + h) - f(x_0)}{h}$$

exists. $f^\prime(x_0)$ is called the derivative. Although this definition seems random, there is a deep meaning behind it: the derivative can be thought of as the rate of change.

Derivative as speed

Let's jump back in time a few hundred years! Derivatives were originally created to describe the speed of moving objects. Suppose that the position of our object at time $t$ is given by the function $x(t)$, and for simplicity, assume that it is moving along a straight line, something like the one below.

Time-distance plot of a moving object

Our goal is to calculate the object's speed at a given time. In high school, we learned that

$$\text{average speed} = \frac{\text{distance}}{\text{time}}.$$

To put this into a quantitative form, if $t_0 < t_1$ are two arbitrary points in time, then

$$\text{average speed between } t_0 \text{ and } t_1 = \frac{x(t_1) - x(t_0)}{t_1 - t_0}.$$

Expressions like $\frac{x(t_1) - x(t_0)}{t_1 - t_0}$ are called difference quotients. Note that if the object moves backwards, the average speed is negative.

The average speed has a simple geometric interpretation. If you replace the object's motion with a constant-velocity motion at the average speed, you'll end up at the same place. In graphical terms, this is equivalent to connecting $(t_0, x(t_0))$ and $(t_1, x(t_1))$ with a straight line. The average speed is just the slope of this line.

Average speed during a time interval

Given this, we can calculate the exact speed at a single time point $t_0$. The idea is simple: the average speed in the small time interval between $t_0$ and $t_0 + \Delta t$ should get closer and closer to the exact speed as $\Delta t$ shrinks. ($\Delta t$ can be negative as well.)

Thus,

$$\text{speed} = x^\prime(t_0) = \lim_{\Delta t \to 0} \frac{x(t_0 + \Delta t) - x(t_0)}{\Delta t}.$$

Geometrically, you can visualize the derivative as the slope of the tangent line at $t_0$, and the difference quotients as slopes of the lines connecting the points of the function's graph at $t_0$ and $t_0 + \Delta t$.
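To watch this limit at work, here is a minimal Python sketch. The position function $x(t) = t^2$ and the helper names are assumptions picked purely for illustration. It prints the average speed over shrinking intervals around $t_0 = 1$, which should approach the exact speed $x^\prime(1) = 2$.

```python
def difference_quotient(x, t0, dt):
    """Average speed of x between t0 and t0 + dt (dt may be negative)."""
    return (x(t0 + dt) - x(t0)) / dt


def position(t):
    # Hypothetical position function chosen for illustration: x(t) = t^2.
    return t ** 2


t0 = 1.0
for dt in [1.0, 0.1, 0.01, 0.001, -0.001]:
    speed = difference_quotient(position, t0, dt)
    print(f"dt = {dt:>7}: average speed = {speed:.4f}")
```

For this particular function, the average speed works out to $2 + \Delta t$, so the error shrinks in step with the interval.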

Derivative as average speeds taken on a smaller and smaller interval

Local extrema and the derivative

Derivatives can tell us a lot more than speed. We've just seen that the derivative equals the slope of the tangent line. With this in mind, take a look at the figure below!

Local maxima and the derivative

Notice that the tangent is perfectly horizontal at the local minima and maxima. In mathematical terms, this is equivalent to $x^\prime(t) = 0$. Can we use this property to find minima and maxima?

Yes, but it's not that simple. Take a look at $x(t) = t^3$. Its derivative is zero at $0$, but that is neither a local minimum nor a local maximum. Without going into the fine details, the best we can do is to compute the second derivative and hope that it'll give some clarity. To be precise, the following theorem holds.

Theorem. Let $f: \mathbb{R} \to \mathbb{R}$ be an arbitrary function that is twice differentiable at some $a \in \mathbb{R}$.

(a) If $f^\prime(a) = 0$ and $f^{\prime \prime}(a) > 0$, then $a$ is a local minimum.
(b) If $f^\prime(a) = 0$ and $f^{\prime \prime}(a) < 0$, then $a$ is a local maximum.
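The theorem can be tried out numerically. The sketch below is only an illustration under stated assumptions: it estimates both derivatives with central differences, and the step size `h` and tolerance `tol` are arbitrary choices, not part of the theorem. Note how $t^3$ at $0$ comes back as inconclusive, since the theorem says nothing when the second derivative is zero.

```python
def second_derivative_test(f, a, h=1e-4, tol=1e-6):
    """Classify the point a using finite-difference estimates of f' and f''."""
    first = (f(a + h) - f(a - h)) / (2 * h)            # approximates f'(a)
    second = (f(a + h) - 2 * f(a) + f(a - h)) / h ** 2  # approximates f''(a)
    if abs(first) > tol:
        return "not a critical point"
    if second > tol:
        return "local minimum"
    if second < -tol:
        return "local maximum"
    return "inconclusive"


print(second_derivative_test(lambda x: x ** 2, 0.0))
print(second_derivative_test(lambda x: -(x ** 2), 0.0))
print(second_derivative_test(lambda x: x ** 3, 0.0))
```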

Why is this important to us? Well, because machine learning is (mostly) just a colossal optimization problem. Construct a parametric model, and find a set of parameters that maximize its performance. As you know, this is done via the famous gradient descent algorithm, which seemingly has nothing to do with our simple result regarding the second derivatives. However, quite the opposite is true: deep down at its core, this stems from the relation of local extrema and the first two derivatives.

In the following, we are going to learn why.

Differential equations 101

What is a differential equation?

Equations play an essential role in mathematics. This is common wisdom, but there is a profound truth behind it. Quite often, equations arise from modeling systems such as interactions in a biochemical network, economic processes, and thousands more. For instance, modeling the metabolic processes in organisms yields linear equations of the form

$$Ax = b, \quad A \in \mathbb{R}^{n \times n}, \quad x, b \in \mathbb{R}^n$$

where the vectors $x$ and $b$ represent the concentration of molecules (where $x$ is the unknown), and the matrix $A$ represents the interactions between them. Linear equations are easy to solve, and we understand quite a lot about them.

However, the equations we have seen so far are unfit to model dynamical systems, as they lack a time component. To describe, for example, the trajectory of a space station orbiting around Earth, we have to describe our models in terms of functions and their derivatives.

For instance, the trajectory of a swinging pendulum can be described by the equation

$$x^{\prime \prime}(t) + \frac{g}{L} \sin x(t) = 0,$$

where

  • $x(t)$ describes the angle of the pendulum from the vertical,
  • $L$ is the length of the (massless) rod that our object of mass $m$ hangs on,
  • and $g \approx 9.81 \, \mathrm{m/s^2}$ is the gravitational acceleration constant.

According to the original interpretation of differentiation, if $x(t)$ describes the position of the pendulum at time $t$, then $x^\prime(t)$ and $x^{\prime \prime}(t)$ describe its velocity and acceleration, where the differentiation is taken with respect to the time $t$.

(In fact, the differential equation $x^{\prime \prime}(t) + \frac{g}{L} \sin x(t) = 0$ is a direct consequence of Newton's second law of motion.)

A swinging pendulum

Equations involving functions and their derivatives, such as the equation of the swinging pendulum above, are called ordinary differential equations, or ODEs for short. Without any exaggeration, their study has been the primary motivating force of mathematics since the 17th century. Trust me when I say this: differential equations are among the most beautiful objects in mathematics. As we are about to see, the gradient descent algorithm is, in fact, an approximate solution of a differential equation.

This part of the post will serve as a quickstart to differential equations. I will mostly follow the fantastic Nonlinear Dynamics and Chaos book by Steven Strogatz. If you ever desire to dig deep into dynamical systems, I wholeheartedly recommend this book to you. (This is one of my favorite math books ever. It reads like a novel. The quality and clarity of its exposition serve as a continuous inspiration for my writing.)

The (slightly more) general form of ODEs

Let's dive straight into the deep waters and start with an example to get a grip on differential equations. Quite possibly, the simplest example is the equation

$$x^\prime(t) = x(t),$$

where the differentiation is taken with respect to the time variable $t$. If, for example, $x(t)$ is the size of a bacterial colony, the equation $x^\prime(t) = x(t)$ describes its population dynamics if the growth is unlimited. Think about $x^\prime(t)$ as the rate at which the population grows: if there are no limitations in space and nutrients, every bacterial cell can freely replicate whenever possible. Thus, since every cell can freely divide, the speed of growth matches the colony's size.

In plain English, the solutions of the equation $x^\prime(t) = x(t)$ are functions whose derivatives are themselves. After a bit of thinking, we can come up with a family of solutions: $x(t) = c e^t$, where $c \in \mathbb{R}$ is an arbitrary constant. (Recall that $e^t$ is an elementary function, and its derivative is itself.)

Some of the solutions are plotted below.

The exponential growth

There are two key takeaways here: differential equations describe dynamical processes that change in time, and they can have multiple solutions. Each solution is determined by two factors: the equation itself, $x^\prime(t) = x(t)$, and an initial condition $x(0) = x^*$. If we specify $x(0) = x^*$, then the value of $c$ is given by

$$x(0) = c e^0 = c = x^*.$$
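A quick numerical sanity check of this, as a sketch only: the helper names and the initial value $x^* = 2.5$ are made up for illustration. The code builds $x(t) = x^* e^t$ and measures how far $x^\prime(t) - x(t)$ is from zero, using a central difference in place of the exact derivative.

```python
import math


def exponential_solution(x_star):
    """Return the function x(t) = x_star * e^t, which has x(0) = x_star."""
    return lambda t: x_star * math.exp(t)


def residual(x, t, h=1e-6):
    """Central-difference estimate of x'(t) - x(t); near zero for a solution of x' = x."""
    derivative = (x(t + h) - x(t - h)) / (2 * h)
    return derivative - x(t)


x = exponential_solution(2.5)  # hypothetical initial value x* = 2.5
print("x(0) =", x(0.0))
print("residual at t = 1:", residual(x, 1.0))
```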

Thus, ODEs have a bundle of solutions, each one determined by the initial condition. So, it's time to discuss differential equations in more general terms!

Definition. (Ordinary differential equations in one dimension.) Let $f: \mathbb{R} \to \mathbb{R}$ be a differentiable function. The equation

$$x^\prime(t) = f(x(t))$$

is called a first-order homogeneous ordinary differential equation.

When it is clear, the dependence on $t$ is often omitted, so we only write $x^\prime = f(x)$. (Some resources denote the time derivative by $\dot x$, a notation that originates from Newton. We will not use this, though it is good to know.)

The term "first-order homogeneous ordinary differential equation" doesn't exactly roll off the tongue, and it is overloaded with heavy terminology. So, let's unpack what is going on here.

The differential equation part is clear: it is a functional equation that involves derivatives. Since the time $t$ is the only variable, the differential equation is ordinary. (As opposed to differential equations involving multivariable functions and partial derivatives, but more on those later.) As only the first derivative is present, the equation is first-order. Second-order would involve second derivatives, and so on. Finally, since the right-hand side $f(x)$ doesn't explicitly depend on the time variable $t$, the equation is homogeneous in time. (In the literature, such equations are more commonly called autonomous.) Homogeneity means that the rules governing our dynamical system don't change over time.

Don't let the $f(x(t))$ part scare you! For instance, in our example $x^\prime(t) = x(t)$, the role of $f$ is played by the identity function $f(x) = x$. In general, $f(x)$ establishes a relation between the quantity $x(t)$ (which can be position, density, etc.) and its derivative, that is, its rate of change.

As we have seen, we think in terms of differential equations and initial conditions that pinpoint solutions among a bundle of functions. Let's put this into a proper mathematical definition!

Definition. (Initial value problems.) Let $x^\prime = f(x)$ be a first-order homogeneous ordinary differential equation and let $x_0 \in \mathbb{R}$ be an arbitrary value. The system

$$\begin{cases} x^\prime &= f(x) \\ x(t_0) &= x_0 \end{cases}$$

is called an initial value problem. If a function $x(t)$ satisfies both conditions, it is said to be a solution to the initial value problem.

Most often, we select $t_0$ to be $0$. After all, we have the freedom to select the origin of the time as we want.

Unfortunately, things are not as simple as they seem. In general, differential equations and initial value problems are tough to solve. Except for a few simple ones, we cannot find exact solutions. (And when I say we, I include every person on the planet.) In these cases, there are two things that we can do: either we construct approximate solutions via numeric methods or turn to qualitative methods that study the behavior of the solutions without actually finding them.

We'll talk about both, but let's turn to the qualitative methods first. As we'll see, looking from a geometric perspective gives us a deep insight into how differential equations work.

A geometric interpretation of differential equations

When finding analytic solutions is not feasible, we look for a qualitative understanding of the solutions, focusing on the local and long-term behavior instead of formulas.

Imagine that, given a differential equation

$$x^\prime(t) = f(x(t)),$$

you are interested in a particular solution that assumes the value $x^*$ at time $t_0$. For instance, you could be studying the dynamics of a bacterial colony and want to provide a predictive model to fit your latest measurement $x(t_0) = x^*$. In the short term, where will your solutions go?

We can immediately notice that if $x(t_0) = x^*$ and $f(x^*) = 0$, then the constant function $x(t) = x^*$ is a solution! These are called equilibrium solutions, and they are extremely important. So, let's make a formal definition!

Definition. (Equilibrium solutions.) Let

$$x^\prime(t) = f(x(t))$$

be a first-order homogeneous ODE, and let $x^\ast \in \mathbb{R}$ be an arbitrary point. If $f(x^\ast) = 0$, then $x^\ast$ is called an equilibrium point of the equation $x^\prime = f(x)$.

For equilibrium points, the constant function $x(t) = x^\ast$ is a solution of $x^\prime = f(x)$. This is called an equilibrium solution.

Think about our recurring example, the simplest ODE $x^\prime(t) = x(t)$. As mentioned, we can interpret this equation as a model of unrestricted population growth under ideal conditions. In that case, $f(x) = x$, and this is zero only for $x = 0$. Therefore, the constant function $x(t) = 0$ is a solution. This makes perfect sense: if a population has zero individuals, no change is going to happen in its size. In other words, the system is in equilibrium.

Think of a pendulum that has stopped moving and reached its resting point at the bottom. However, pendulums have two equilibria: one at the top and one at the bottom. (Let's suppose that the mass is held by a massless rod. Otherwise, it would collapse.) At the bottom, you can push the hanging mass all you want; it'll return to rest. However, at the top, any small push would disrupt the equilibrium state, to which it would never return.

To shed light on this phenomenon, let's look at another example: the famous logistic equation

$$x^\prime(t) = x(t) (1 - x(t)).$$

From a population dynamics perspective, if our favorite equation $x^\prime(t) = x(t)$ describes the unrestricted growth of a bacterial colony, the logistic equation models the population growth under a resource constraint. If we assume that $1$ is the total capacity of our population, the growth becomes more difficult as the size approaches this limit. Thus, the population's rate of change $x^\prime(t)$ can be modelled as $x(t) (1 - x(t))$, where the term $1 - x(t)$ slows down the process as the colony nears its carrying capacity.

We can write the logistic equation in the general form $x^\prime = f(x)$ by setting $f(x) = x (1 - x)$. Do you recall the relation between derivatives and monotonicity? Translated to the differential equation $x^\prime = f(x)$, this reveals the flow of our solutions! To be specific,

$$x(t) \text{ is } \begin{cases} \text{increasing} & \text{if } f(x) > 0, \\ \text{decreasing} & \text{if } f(x) < 0, \\ \text{constant} & \text{if } f(x) = 0. \end{cases}$$
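The case distinction above translates almost word for word into code. The following is a rough sketch, with function names invented for this illustration, that reports the direction of the flow of the logistic equation at a handful of sample points.

```python
def flow_direction(f, x, tol=1e-12):
    """Tell whether a solution of x' = f(x) passing through x is rising or falling."""
    value = f(x)
    if value > tol:
        return "increasing"
    if value < -tol:
        return "decreasing"
    return "constant"


def logistic_rhs(x):
    # Right-hand side of the logistic equation, f(x) = x(1 - x).
    return x * (1 - x)


for x in [-0.5, 0.0, 0.5, 1.0, 1.5]:
    print(f"x = {x:>4}: {flow_direction(logistic_rhs, x)}")
```

Between $0$ and $1$ the solutions rise; below $0$ and above $1$ they fall; at $0$ and $1$ they stay put.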

We can visualize this in the so-called phase portrait.

Phase portrait of the logistic equation

Thus, the monotonicity describes the long-term behavior in terms of the initial value:

$$\lim_{t \to \infty} x(t) = \begin{cases} 1 & \text{if } x(0) > 0, \\ 0 & \text{if } x(0) = 0, \\ -\infty & \text{if } x(0) < 0. \end{cases}$$

(For $x(0) < 0$, the solution in fact diverges to $-\infty$ in finite time.)

With a little bit of calculation (whose details are not essential for us), we can write the solutions as

$$x(t) = \frac{1}{1 + c e^{-t}},$$

where $c \in \mathbb{R}$ is an arbitrary constant. For $c = 1$, this is the famous sigmoid function. You can check by hand that these are indeed solutions. We can even plot them, as shown below.
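For the record, here is how the check by hand can go. Differentiating the formula and computing $1 - x(t)$ separately gives

$$\begin{align*} x^\prime(t) &= \frac{c e^{-t}}{(1 + c e^{-t})^2}, \\ 1 - x(t) &= \frac{c e^{-t}}{1 + c e^{-t}}, \end{align*}$$

so that

$$x(t)(1 - x(t)) = \frac{1}{1 + c e^{-t}} \cdot \frac{c e^{-t}}{1 + c e^{-t}} = x^\prime(t).$$

Moreover, $x(0) = \frac{1}{1 + c}$, so a nonzero initial value $x(0) = x_0$ corresponds to $c = \frac{1}{x_0} - 1$.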

Solution of the logistic equation

As we can see, the monotonicity of the solutions is as we predicted.

We can characterize the equilibria based on the long-term behavior of nearby solutions. (In the case of our logistic equation, the equilibria are $0$ and $1$.) This can be connected to the local behavior of $f$: if it decreases around the equilibrium $x^*$, it attracts the nearby solutions. On the other hand, if $f$ increases around $x^*$, the nearby solutions are repelled.

This gives rise to the concept of stable and unstable equilibria.

Definition. (Stable and unstable equilibria.) Let $x^\prime = f(x)$ be a first-order homogeneous ordinary differential equation, and suppose that $f$ is differentiable. Moreover, let $x^\ast$ be an equilibrium of the equation.

$x^\ast$ is called a stable equilibrium if there is a neighborhood $(x^\ast - \varepsilon, x^\ast + \varepsilon)$ around $x^\ast$ such that for all $x_0 \in (x^\ast - \varepsilon, x^\ast + \varepsilon)$, the solution of the initial value problem

$$\begin{cases} x^\prime &= f(x) \\ x(0) &= x_0 \end{cases}$$

converges towards $x^\ast$. (That is, $\lim_{t \to \infty} x(t) = x^\ast$ holds.) If $x^\ast$ is not stable, it is called unstable.

In the case of the logistic ODE $x^\prime = x(1 - x)$, $x^* = 1$ is a stable and $x^* = 0$ is an unstable equilibrium. This makes sense given its population dynamics interpretation: the equilibrium $x^* = 1$ means that the population is at maximum capacity. If the size is slightly above the capacity $1$, some specimens die due to starvation; if it is slightly below, the colony grows until it reaches its constraints. On the other hand, no matter how small the population is, it won't ever go extinct in this ideal model.
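Since an equilibrium attracts nearby solutions when $f$ decreases around it and repels them when $f$ increases, the sign of the slope of $f$ at $x^*$ gives a practical test. The sketch below estimates that slope with a central difference; the function names, step size, and tolerance are assumptions for illustration, and a slope of zero is reported as undetermined because this simple test cannot decide that case.

```python
def classify_equilibrium(f, x_star, h=1e-6, tol=1e-9):
    """Classify an equilibrium of x' = f(x) by the sign of the slope of f at x_star."""
    slope = (f(x_star + h) - f(x_star - h)) / (2 * h)  # approximates f'(x_star)
    if slope < -tol:
        return "stable"
    if slope > tol:
        return "unstable"
    return "undetermined by this test"


def logistic_rhs(x):
    # Right-hand side of the logistic equation, f(x) = x(1 - x).
    return x * (1 - x)


for x_star in [0.0, 1.0]:
    print(f"x* = {x_star}: {classify_equilibrium(logistic_rhs, x_star)}")
```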

A continuous version of gradient ascent

Now, let's talk about maximizing a function $F: \mathbb{R} \to \mathbb{R}$. Suppose that $F$ is twice differentiable, and we denote its derivative by $F^\prime = f$. Luckily, the local maxima of $F$ can be found with the help of its second derivative by looking for $x^*$ where $f(x^*) = 0$ and $f^\prime(x^*) < 0$.

Does this look familiar? If $f(x^*) = 0$ indeed holds, then $x(t) = x^*$ is an equilibrium solution; and since $f^\prime(x^*) < 0$, it attracts the nearby solutions as well. This means that if $x_0$ is drawn from the basin of attraction and $x(t)$ is the solution of the initial value problem

$$\begin{cases} x^\prime &= f(x) \\ x(0) &= x_0, \end{cases}$$

then $\lim_{t \to \infty} x(t) = x^*$. In other words, the solution converges towards $x^*$, a local maximum of $F$! This is gradient ascent in a continuous version.

We are happy, but there is an issue. We've talked about how hard solving differential equations is. For a general $F$, we have no prospect of actually finding the solutions. Fortunately, we can approximate them.

Gradient ascent as a discretized differential equation

When doing differentiation in practice, derivatives are often approximated numerically by the forward difference

$$x^\prime(t) \approx \frac{x(t + h) - x(t)}{h}.$$

If $x(t)$ is indeed the solution for the corresponding initial value problem, we are in luck! Using forward differences, we can take a small step from $0$ and approximate $x(h)$ by substituting the forward difference into the differential equation. To be precise, we have

$$\frac{x(h) - x(0)}{h} \approx f(x(0)),$$

from which

$$x(h) \approx x(0) + h f(x(0))$$

follows. By defining $x_0$ and $x_1$ by

$$\begin{align*} x_0 &:= x(0), \\ x_1 &:= x_0 + h f(x_0), \end{align*}$$

we have $x_1 \approx x(h)$. If this looks like the first step of the gradient ascent to you, you are on the right track. Using the forward difference once again, this time from the point $x(h)$, we obtain

$$\begin{align*} x(2h) &\approx x(h) + h f(x(h)) \\ &\approx x_1 + h f(x_1), \end{align*}$$

thus by defining $x_2 := x_1 + h f(x_1)$, we have $x_2 \approx x(2h)$. Notice that in $x_2$, two kinds of approximation errors are accumulated: first the forward difference, then the approximation error of the previous step.

This motivates us to define the recursive sequence

$$\begin{align*} x_0 &:= x(0), \\ x_{n+1} &:= x_n + h f(x_n), \end{align*}$$

which approximates $x(n h)$ with $x_n$, as this is implied by the very definition. This recursive sequence is the gradient ascent itself, and the small step $h$ is the learning rate! In the context of differential equations, this is called the Euler method.

Without going into the details, if $h$ is small enough and $f$ "behaves properly", the Euler method will converge to the equilibrium solution $x^*$. (Whatever proper behavior might mean.)
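The recursion is short enough to write down directly. Here is a minimal sketch of the Euler method in Python; the function name `euler` and the test equation $x^\prime = x$ with $x(0) = 1$ are illustrative choices. Since the exact solution at $t = 1$ is $e$, we can see how the approximation improves as the step $h$ shrinks.

```python
import math


def euler(f, x0, h, n_steps):
    """Return [x_0, x_1, ..., x_n] with x_{k+1} = x_k + h * f(x_k)."""
    xs = [x0]
    for _ in range(n_steps):
        xs.append(xs[-1] + h * f(xs[-1]))
    return xs


# Approximate x(1) for x' = x, x(0) = 1 with two different step sizes.
for h, n_steps in [(0.1, 10), (0.001, 1000)]:
    approx = euler(lambda x: x, 1.0, h, n_steps)[-1]
    print(f"h = {h}: x_n = {approx:.5f}, exact x(1) = {math.e:.5f}")
```

For this equation, each step multiplies the current value by $1 + h$, so $x_n = (1 + h)^n$, which is the familiar compound-interest approximation of $e^{nh}$.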

We only have one more step: to turn everything into gradient descent instead of ascent. This is extremely simple, as gradient descent is just gradient ascent applied to $-F$. Think about it: minimizing a function $F$ is the same as maximizing its negative $-F$. Since the derivative of $-F$ is $-f$, the update rule becomes $x_{n+1} := x_n - h f(x_n)$. And with that, we are done! The famous gradient descent is a consequence of dynamical systems converging towards their stable equilibria, and this is beautiful.
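Spelled out in code, gradient descent on $F$ is the Euler method applied to $x^\prime = -F^\prime(x)$. The sketch below uses a hypothetical objective $F(x) = (x - 3)^2$, whose minimum is at $x = 3$; the objective, the starting point, and the learning rate are all assumptions chosen for illustration.

```python
def gradient_descent(grad, x0, learning_rate, n_steps):
    """Euler method for x' = -grad(x): repeatedly step against the derivative."""
    x = x0
    for _ in range(n_steps):
        x = x - learning_rate * grad(x)
    return x


def objective_grad(x):
    # Derivative of the hypothetical objective F(x) = (x - 3)^2.
    return 2 * (x - 3)


minimizer = gradient_descent(objective_grad, x0=0.0, learning_rate=0.1, n_steps=100)
print(f"approximate minimizer: {minimizer:.6f}")
```

The step size matters: with `learning_rate=1.5`, each step maps the error $x_n - 3$ to $-2(x_n - 3)$, so the iterates run away from the minimum instead of approaching it.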

The gradient ascent in action

To see the gradient ascent (that is, the Euler method) in action, we should go back to our good old example: the logistic equation $x^\prime = x(1 - x)$. So, suppose that we want to find the local maxima of the function

$$F(x) = \frac{1}{2}x^2 - \frac{1}{3}x^3,$$

which is plotted below.

The logistic equation as an optimization problem

First, we can use what we learned and find the maxima using the derivative $f(x) = F^\prime(x) = x(1 - x)$, concluding that there is a local maximum at $x^* = 1$. (Don't just take my word, pick up a pencil and work it out!)

Since $f(x^*) = F^\prime(x^*) = 0$ and $f^\prime(x^*) < 0$, the point $x^*$ is a stable equilibrium of the logistic equation

$$x^\prime = x(1 - x).$$

Thus, if the initial value $x(0) = x_0$ is sufficiently close to $x^* = 1$, the solution $x(t)$ of the initial value problem

$$\begin{cases} x^\prime &= x(1 - x), \\ x(0) &= x_0 \end{cases}$$

satisfies $\lim_{t \to \infty} x(t) = x^*$. (In fact, we can select any initial value $x_0$ from the infinite interval $(0, \infty)$, and the convergence will hold.) Upon discretization via the Euler method, we obtain the recursive sequence

$$\begin{align*} x_0 &= x(0), \\ x_{n+1} &= x_n + h x_n (1 - x_n). \end{align*}$$

This process is visualized below.

The Euler method, illustrated on the logistic equation: solving $x^\prime = x(1 - x)$ via the Euler method. (For visualization purposes, the initial value was set at $t_0 = -5$.)

We can even take the discrete solution provided by the Euler method and plot it on the $x$-$F(x)$ plane.

Gradient ascent in action
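If you'd like to reproduce the experiment yourself, the following sketch runs the recursion in plain Python. The starting point $x_0 = 0.1$, the step $h = 0.1$, and the number of steps are arbitrary choices for this illustration. For comparison, it also evaluates the exact solution $x(t) = \frac{1}{1 + c e^{-t}}$ with $c = \frac{1}{x_0} - 1$.

```python
import math


def F(x):
    """The function we maximize: F(x) = x^2 / 2 - x^3 / 3."""
    return x ** 2 / 2 - x ** 3 / 3


def f(x):
    """Its derivative, the right-hand side of the logistic equation."""
    return x * (1 - x)


def gradient_ascent(x0, h, n_steps):
    """Euler method for x' = f(x), i.e. gradient ascent on F with learning rate h."""
    xs = [x0]
    for _ in range(n_steps):
        xs.append(xs[-1] + h * f(xs[-1]))
    return xs


def exact_solution(x0, t):
    """Exact solution of the logistic equation with x(0) = x0 (x0 nonzero)."""
    c = 1 / x0 - 1
    return 1 / (1 + c * math.exp(-t))


h = 0.1
xs = gradient_ascent(x0=0.1, h=h, n_steps=200)
for n in [0, 10, 50, 200]:
    exact = exact_solution(0.1, n * h)
    print(f"n = {n:3d}: x_n = {xs[n]:.6f}, x(nh) = {exact:.6f}, F(x_n) = {F(xs[n]):.6f}")
```

The iterates climb towards $x^* = 1$, and $F(x_n)$ approaches the maximum value $F(1) = \frac{1}{6}$.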

Conclusion

To sum up what we've seen so far, our entire goal was to understand the very principles of gradient descent, the most important optimization algorithm in machine learning. Its main principle is straightforward: to find a local minimum of a function, first find the direction of decrease, then take a small step that way. This seemingly naive algorithm has a foundation that lies deep within differential equations. It turns out that if we look at our functions as rules determining a dynamical system, local extrema correspond to equilibrium states. These dynamical systems are described by differential equations, and the local maxima are stable equilibrium states that attract the nearby solutions. From this viewpoint, the gradient descent algorithm is nothing else than a numerical solution to this equation.

What we've seen so far only covers the single-variable case, and as I have probably said many times, machine learning is done in millions of dimensions. Still, the intuition we built up will be our guide in the study of multivariable functions and high-dimensional spaces. There, the principles are the same, but the objects of study are much more complex. The main challenge in multivariable calculus is to manage the complexity, and this is where our good friends, vectors and matrices, will do much of the heavy lifting. But that's for another day!
