There is a mind-blowing application of matrix multiplication: doing recursion (almost) at the speed of light!
By the end of this post, you'll learn precisely how. Trust me, if you are into programming and math, you want to know this trick.
Let's start with the simplest example of recursion: the Fibonacci numbers. Each Fibonacci number is the sum of the two preceding ones, $F_n = F_{n-1} + F_{n-2}$. The recursion starts with $F_0 = 0$ and $F_1 = 1$.
In Python, the implementation is rather straightforward.
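A minimal sketch of such a recursive implementation could look like this. The function name `fibonacci` matches the one used in the text; the base cases $F_0 = 0$ and $F_1 = 1$ and the input check are assumptions.

```python
def fibonacci(n):
    """Return the n-th Fibonacci number using the textbook recursion."""
    if n < 0:
        raise ValueError("n must be non-negative")
    # Base cases (assumed): F_0 = 0 and F_1 = 1.
    if n < 2:
        return n
    # Every call with n >= 2 spawns two more calls.
    return fibonacci(n - 1) + fibonacci(n - 2)


print([fibonacci(n) for n in range(10)])
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```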
Can you guess the issue? The execution is extremely slow, as each function call involves two more calls. Thus, fibonacci(n) calls itself many times!
If we measure the execution, we find that the time increases exponentially with n.
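Wall-clock timings depend on the machine, so here is a rough, machine-independent proxy instead: counting how many function calls the naive recursion makes. The helper `count_calls` is hypothetical and only mirrors the call structure of a two-branch recursion with base cases at n < 2.

```python
def count_calls(n):
    """Number of calls the naive recursion makes for input n, including the first."""
    if n < 2:
        return 1  # a base case is a single call
    # One call for this level, plus everything the two sub-calls trigger.
    return 1 + count_calls(n - 1) + count_calls(n - 2)


for n in (5, 10, 15, 20):
    print(n, count_calls(n))
# 5 15
# 10 177
# 15 1973
# 20 21891
```

Each time n grows by 5, the number of calls grows by roughly a factor of 11, which is the exponential blow-up in action.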
Is there any way to improve this? Yes. Spoiler alert: it's all about matrix multiplication.
Matrix multiplication from another perspective
Now, let's talk about matrix multiplication. What do you get when multiplying a row and a column vector? Their inner product.
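How does this relate to the Fibonacci numbers? Simple: we can express the Fibonacci recursion in terms of vectors, as the inner product of a row vector holding the two previous Fibonacci numbers and a column vector of ones.

$$F_n = \begin{bmatrix} F_{n-1} & F_{n-2} \end{bmatrix} \begin{bmatrix} 1 \\ 1 \end{bmatrix}$$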
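To make the inner-product view concrete, here is a small sketch in plain Python. The helper `inner_product` is a hypothetical name, and the vectors are ordinary lists.

```python
def inner_product(row, column):
    """Inner product of a row vector and a column vector given as lists."""
    if len(row) != len(column):
        raise ValueError("vectors must have the same length")
    return sum(r * c for r, c in zip(row, column))


# F_4 = 3 and F_3 = 2, so the inner product with [1, 1] gives F_5.
print(inner_product([3, 2], [1, 1]))  # 5
```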
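With one small trick, we can turn this into an iterative process. By adding a second column to the column vector on the right, we can copy the $(n-1)$-th Fibonacci number over. Thus, we have a recursive relation:

$$\begin{bmatrix} F_n & F_{n-1} \end{bmatrix} = \begin{bmatrix} F_{n-1} & F_{n-2} \end{bmatrix} \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}$$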
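We can express the above without recursion! Applying our matrix recursion n times to the starting vector $\begin{bmatrix} F_1 & F_0 \end{bmatrix} = \begin{bmatrix} 1 & 0 \end{bmatrix}$, we obtain an explicit formula to calculate the Fibonacci numbers:

$$\begin{bmatrix} F_{n+1} & F_n \end{bmatrix} = \begin{bmatrix} 1 & 0 \end{bmatrix} \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^n$$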
This is much faster to compute.
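As a sketch of the idea, assuming the row-vector convention $[F_{n+1}, F_n] = [F_1, F_0] A^n$ with $A = [[1, 1], [1, 0]]$, the repeated multiplication can be written as a simple loop. The names `step` and `fibonacci_iterative` are hypothetical.

```python
def step(vector):
    """Multiply the row vector [a, b] by the matrix [[1, 1], [1, 0]]."""
    a, b = vector
    # [a, b] @ [[1, 1], [1, 0]] = [a*1 + b*1, a*1 + b*0] = [a + b, a]
    return [a + b, a]


def fibonacci_iterative(n):
    """Return F_n by applying the matrix recursion n times."""
    vector = [1, 0]  # [F_1, F_0]
    for _ in range(n):
        vector = step(vector)
    # After n steps the vector is [F_{n+1}, F_n].
    return vector[1]


print(fibonacci_iterative(10))  # 55
```

This needs n vector-matrix products, so the work grows linearly with n.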
Just one more step!
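By noticing that the Fibonacci numbers start with $F_0 = 0$ and $F_1 = 1$ (and that running the recursion backwards gives $F_{-1} = F_1 - F_0 = 1$), we can stack two shifted vectors on top of each other and obtain the $(n+1)$-th, $n$-th, and $(n-1)$-th Fibonacci numbers purely by raising the right matrix to the $n$-th power:

$$\begin{bmatrix} F_{n+1} & F_n \\ F_n & F_{n-1} \end{bmatrix} = \begin{bmatrix} F_1 & F_0 \\ F_0 & F_{-1} \end{bmatrix} \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^n = \begin{bmatrix} 1 & 0 \\ 0 & 1 \end{bmatrix} \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^n$$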
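Cleaning everything up, we obtain an extremely elegant formula:

$$\begin{bmatrix} F_{n+1} & F_n \\ F_n & F_{n-1} \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^n$$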
Tell me it's not beautiful. I dare you.
We can quickly implement the formula in NumPy.
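One way to do this, as a sketch: `np.linalg.matrix_power` raises a square matrix to an integer power, and the off-diagonal entry of the result is $F_n$. The function name `fibonacci_matrix` and the explicit `int64` dtype are assumptions.

```python
import numpy as np


def fibonacci_matrix(n):
    """Return F_n as the off-diagonal entry of [[1, 1], [1, 0]] ** n."""
    A = np.array([[1, 1], [1, 0]], dtype=np.int64)
    # A ** n = [[F_{n+1}, F_n], [F_n, F_{n-1}]]
    # Only exact while all entries fit into a 64-bit integer.
    return int(np.linalg.matrix_power(A, n)[0, 1])


print(fibonacci_matrix(10))  # 55
```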
Let's see how it performs!
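As you can see, it is much faster than the vanilla version. Moreover, while the vanilla implementation has exponential time complexity, this one needs at most n matrix multiplications, so it is linear in n (and only logarithmic if the power is computed by repeated squaring). Quite the difference!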
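A small caveat, though. NumPy integers have a fixed width, so the results silently overflow for larger n: already $F_{93}$ no longer fits into a signed 64-bit integer. That makes NumPy unsuitable for computing large Fibonacci numbers exactly.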
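The overflow is easy to reproduce. The sketch below assumes a 64-bit integer dtype and compares the NumPy result with an exact value computed using Python's own integers; `fibonacci_reference` is a hypothetical helper.

```python
import numpy as np


def fibonacci_reference(n):
    """Exact F_n using Python's arbitrary-precision integers."""
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a + b
    return a


A = np.array([[1, 1], [1, 0]], dtype=np.int64)
n = 93

from_numpy = int(np.linalg.matrix_power(A, n)[0, 1])
exact = fibonacci_reference(n)

print(exact > np.iinfo(np.int64).max)  # True: F_93 does not fit into int64
print(from_numpy == exact)             # False: the NumPy result wrapped around
```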
Homework for you: write a plain Python implementation that takes advantage of the arbitrarily large ints!
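If you want to compare notes afterwards, here is one possible sketch, not the only way to do it. It uses nested lists of Python ints and repeated squaring; the names `mat_mul`, `mat_pow`, and `fibonacci_bigint` are hypothetical.

```python
def mat_mul(X, Y):
    """Multiply two 2x2 matrices given as nested lists of Python ints."""
    return [
        [X[0][0] * Y[0][0] + X[0][1] * Y[1][0], X[0][0] * Y[0][1] + X[0][1] * Y[1][1]],
        [X[1][0] * Y[0][0] + X[1][1] * Y[1][0], X[1][0] * Y[0][1] + X[1][1] * Y[1][1]],
    ]


def mat_pow(M, n):
    """Raise a 2x2 matrix to the n-th power (n >= 0) by repeated squaring."""
    result = [[1, 0], [0, 1]]  # identity matrix
    while n > 0:
        if n & 1:  # this bit of the exponent is set
            result = mat_mul(result, M)
        M = mat_mul(M, M)
        n >>= 1
    return result


def fibonacci_bigint(n):
    """Return F_n exactly, for any non-negative integer n."""
    return mat_pow([[1, 1], [1, 0]], n)[0][1]


print(fibonacci_bigint(100))  # 354224848179261915075
```

Repeated squaring needs only about log2(n) matrix multiplications, and Python ints never overflow, so this stays exact for large n.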