Why Momentum Really Works

[Interactive figure: the path from the starting point to the optimum, with sliders controlling the step-size α and the momentum β.]
We often think of Momentum as a means of dampening oscillations and speeding up the iterations, leading to faster convergence. But it has other interesting behavior. It allows a larger range of step-sizes to be used, and creates its own oscillations. What is going on?

Here’s a popular story about momentum [1, 2, 3]: gradient descent is a man walking down a hill. He follows the steepest path downwards; his progress is slow, but steady. Momentum is a heavy ball rolling down the same hill. The added inertia acts both as a smoother and an accelerator, dampening oscillations and causing us to barrel through narrow valleys, small humps and local minima.

This standard story isn’t wrong, but it fails to explain many important behaviors of momentum. In fact, momentum can be understood far more precisely if we study it on the right model.

One nice model is the convex quadratic. This model is rich enough to reproduce momentum’s local dynamics in real problems, and yet simple enough to be understood in closed form. This balance gives us powerful traction for understanding this algorithm.


We begin with gradient descent. The algorithm has many virtues, but speed is not one of them. It is simple — when optimizing a smooth function $f$, we make a small step in the direction of the negative gradient, $w^{k+1} = w^k-\alpha\nabla f(w^k)$. For a small enough step-size, gradient descent makes a monotonic improvement at every iteration. It always converges, albeit to a local minimum. And under a few weak curvature conditions it can even get there at an exponential rate.
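
For concreteness, here is a minimal Python sketch of this update (not from the article; the toy objective and the name `grad_f` are illustrative assumptions):

```python
import numpy as np

def gradient_descent(grad_f, w0, alpha=0.01, iters=500):
    """Plain gradient descent: w^{k+1} = w^k - alpha * grad_f(w^k)."""
    w = np.asarray(w0, dtype=float)
    for _ in range(iters):
        w = w - alpha * grad_f(w)
    return w

# Toy example: f(w) = 0.5 * ||w||^2, whose gradient is simply w.
w_min = gradient_descent(lambda w: w, w0=[1.0, -2.0])
```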

But the exponential decrease, though appealing in theory, can often be infuriatingly small. Things often begin quite well — with an impressive, almost immediate decrease in the loss. But as the iterations progress, things start to slow down. You start to get a nagging feeling you’re not making as much progress as you should be. What has gone wrong?

The problem could be the optimizer’s old nemesis, pathological curvature. Pathological curvature is, simply put, regions of $f$ which aren’t scaled properly. The landscapes are often described as valleys, trenches, canals and ravines. The iterates either jump between valleys, or approach the optimum in small, timid steps. Progress along certain directions grinds to a halt. In these unfortunate regions, gradient descent fumbles.

Momentum proposes the following tweak to gradient descent. We give gradient descent a short-term memory:
$$\begin{aligned} z^{k+1}&=\beta z^{k}+\nabla f(w^{k})\\ w^{k+1}&=w^{k}-\alpha z^{k+1} \end{aligned}$$
The change is innocent, and costs almost nothing. When $\beta = 0$, we recover gradient descent. But for $\beta = 0.99$ (sometimes $0.999$, if things are really bad), this appears to be the boost we need. Our iterations regain the speed and boldness they had lost, speeding to the optimum with a renewed energy.
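
A minimal sketch of the same tweak, continuing the illustrative Python style used above (the names and defaults are assumptions, not the article’s code):

```python
import numpy as np

def momentum(grad_f, w0, alpha=0.01, beta=0.99, iters=500):
    """Gradient descent with a short-term memory:
        z^{k+1} = beta * z^k + grad_f(w^k)
        w^{k+1} = w^k - alpha * z^{k+1}
    Setting beta = 0 recovers plain gradient descent."""
    w = np.asarray(w0, dtype=float)
    z = np.zeros_like(w)
    for _ in range(iters):
        z = beta * z + grad_f(w)
        w = w - alpha * z
    return w

# Usage mirrors the gradient_descent sketch above, e.g.:
# momentum(lambda w: w, w0=[1.0, -2.0], alpha=0.01, beta=0.9)
```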

Optimizers call this minor miracle “acceleration”.

The new algorithm may seem at first glance like a cheap hack. A simple trick to get around gradient descent’s more aberrant behavior — a smoother for oscillations between steep canyons. But the truth, if anything, is the other way round. It is gradient descent which is the hack. First, momentum gives up to a quadratic speedup on many functions. This is no small matter — this is similar to the speedup you get from the Fast Fourier Transform, Quicksort, and Grover’s Algorithm. When the universe gives you quadratic speedups, you should start to pay attention.

But there’s more. A lower bound, courtesy of Nesterov [5], states that momentum is, in a certain very narrow and technical sense, optimal. Now, this doesn’t mean it is the best algorithm for all functions in all circumstances. But it does satisfy some curiously beautiful mathematical properties which scratch a very human itch for perfection and closure. But more on that later. Let’s say this for now — momentum is an algorithm for the book.


First Steps: Gradient Descent

We begin by studying gradient descent on the simplest model possible which isn’t trivial — the convex quadratic,
$$f(w) = \tfrac{1}{2}w^TAw - b^Tw, \qquad w \in \mathbf{R}^n.$$
Assume $A$ is symmetric and invertible; then the optimal solution $w^{\star}$ occurs at
$$w^{\star} = A^{-1}b.$$
Simple as this model may be, it is rich enough to approximate many functions (think of $A$ as your favorite model of curvature — the Hessian, Fisher Information Matrix [6], etc.) and captures all the key features of pathological curvature. And more importantly, we can write an exact closed formula for gradient descent on this function.
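
As a quick numerical check, here is a sketch with a small, made-up quadratic (the particular $A$ and $b$ are not taken from the figures):

```python
import numpy as np

# A small convex quadratic: A symmetric positive definite, deliberately ill-conditioned.
A = np.array([[1.0, 0.0],
              [0.0, 0.01]])
b = np.array([1.0, 1.0])

f = lambda w: 0.5 * w @ A @ w - b @ w

w_star = np.linalg.solve(A, b)           # the optimum w* = A^{-1} b
assert np.allclose(A @ w_star - b, 0.0)  # the gradient A w - b vanishes at w*
```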

This is how it goes. Since $\nabla f(w)=Aw - b$, the iterates are
$$w^{k+1}=w^{k}- \alpha (Aw^{k} - b).$$
Here’s the trick. There is a very natural space to view gradient descent where all the dimensions act independently — the eigenvectors of $A$.


Every symmetric matrix $A$ has an eigenvalue decomposition
$$A=Q\ \text{diag}(\lambda_{1},\ldots,\lambda_{n})\ Q^{T},\qquad Q = [q_1,\ldots,q_n],$$
and, as per convention, we will assume that the $\lambda_i$’s are sorted, from smallest $\lambda_1$ to biggest $\lambda_n$. If we perform a change of basis, $x^{k} = Q^T(w^{k} - w^\star)$, the iterations break apart, becoming:
$$\begin{aligned} x_{i}^{k+1} & = x_{i}^{k}-\alpha \lambda_i x_{i}^{k} \\ &= (1-\alpha\lambda_i)x^k_i=(1-\alpha \lambda_i)^{k+1}x^0_i. \end{aligned}$$
Moving back to our original space $w$, we can see that
$$w^k - w^\star = Qx^k=\sum_{i=1}^{n} x^0_i(1-\alpha\lambda_i)^k q_i$$
and there we have it — gradient descent in closed form.
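
As a sanity check, this sketch compares the closed form against simply running the iteration, reusing the same made-up quadratic as above:

```python
import numpy as np

A = np.array([[1.0, 0.0],
              [0.0, 0.01]])
b = np.array([1.0, 1.0])
w_star = np.linalg.solve(A, b)

alpha, k = 1.0, 50
lam, Q = np.linalg.eigh(A)    # A = Q diag(lam) Q^T, eigenvalues in ascending order
w0 = np.zeros(2)
x0 = Q.T @ (w0 - w_star)      # initial error in the eigenbasis, x^0 = Q^T (w^0 - w*)

# Closed form:  w^k - w* = sum_i x_i^0 (1 - alpha * lam_i)^k q_i
closed_form = w_star + Q @ ((1 - alpha * lam) ** k * x0)

# Running w^{k+1} = w^k - alpha * (A w^k - b) for k steps lands on the same point.
w = w0.copy()
for _ in range(k):
    w = w - alpha * (A @ w - b)
assert np.allclose(w, closed_form)
```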

Decomposing the Error

The above equation admits a simple interpretation. Each element of $x^0$ is the component of the error in the initial guess in the $Q$-basis. There are $n$ such errors, and each of these errors follows its own, solitary path to the minimum, decreasing exponentially with a compounding rate of $1-\alpha\lambda_i$. The closer that number is to $1$, the slower it converges.

For most step-sizes, the eigenvectors with the largest eigenvalues converge the fastest. This triggers an explosion of progress in the first few iterations, before things slow down as the smaller eigenvectors’ struggles are revealed. By writing the contributions of each eigenspace’s error to the loss,
$$f(w^{k})-f(w^{\star})=\tfrac{1}{2}\sum_{i}(1-\alpha\lambda_{i})^{2k}\lambda_{i}[x_{i}^{0}]^2,$$
we can visualize the contributions of each error component to the loss.
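
The following sketch evaluates this decomposition on the same made-up quadratic used in the earlier snippets, checking the per-component terms against the loss computed directly:

```python
import numpy as np

A = np.array([[1.0, 0.0],
              [0.0, 0.01]])
b = np.array([1.0, 1.0])
f = lambda w: 0.5 * w @ A @ w - b @ w
w_star = np.linalg.solve(A, b)

lam, Q = np.linalg.eigh(A)
x0 = Q.T @ (np.zeros(2) - w_star)   # initial error in the eigenbasis

alpha, k = 1.0, 100
contrib = 0.5 * lam * (1 - alpha * lam) ** (2 * k) * x0 ** 2  # one term per eigenvalue

# Check against running gradient descent for k steps and evaluating the loss directly.
w = np.zeros(2)
for _ in range(k):
    w = w - alpha * (A @ w - b)
assert np.isclose(f(w) - f(w_star), contrib.sum())
print(contrib)  # the small-eigenvalue component dominates whatever error remains
```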

Optimization can be seen as a combination of several component problems, shown here as 1, 2, and 3, with eigenvalues $\lambda_1=0.01$, $\lambda_2=0.1$, and $\lambda_3=1$ respectively.
[Interactive figure: step-size slider, with the optimal step-size indicated.]