Neural networks are stunningly powerful. This is old news: deep learning is state-of-the-art in many fields, like computer vision and natural language processing.
Why is this the case? Why are neural networks so effective?
As always in machine learning, there is a precise mathematical explanation. Simply put, the family of functions described by a neural network model is extremely expressive and able to closely approximate (almost) any function.
But how can a family of functions be expressive? What does it mean to approximate “any” function, and how can we quantify how good of a job neural networks do?
These concepts seem vague and difficult at first glance. The goal of this post is to shed light on them, leading us to understand the surprising effectiveness of neural networks.
This post will eventually make its way into my Mathematics of Machine Learning book, currently out in early access.
If you would like to receive my deep-dive explainers on machine learning and its mathematics, consider getting early access or supporting me with a premium subscription to The Palindrome!
Machine learning as function approximation
Let's formulate the classical supervised learning task from an abstract viewpoint. Suppose that we have our dataset
D = {(x₁, y₁), (x₂, y₂), …, (xₙ, yₙ)},
where xₖ is a data point and yₖ is the corresponding ground truth. The observation yₖ can be
a categorical variable,
a numerical variable,
a probability distribution (in the case of classification),
or a vector in general.
The task is simply to find a function g(x) for which
g(xₖ) is approximately yₖ,
and g(x) is computationally feasible.
To achieve this, we fix a parametrized family of functions in advance and select the parameter configuration that best fits our data. This is our model.
For instance, linear regression uses the parametric family of functions
g(x) = ax + b,
with a and b as parameters.
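To make this concrete, here is a minimal sketch in Python (my own illustration with a made-up underlying function and synthetic data, not code from the original post) of selecting the best-fitting member of the family g(x) = ax + b by least squares:

```python
# Minimal sketch: fitting the parametric family g(x) = a*x + b to data.
# The "true" function f and the dataset below are invented for illustration.
import numpy as np

rng = np.random.default_rng(0)

def f(x):
    return 2.0 * x + 1.0  # the (in practice unknown) underlying function

# dataset: pairs (x_k, y_k), noisy observations of f
x = rng.uniform(-1.0, 1.0, size=50)
y = f(x) + rng.normal(scale=0.1, size=x.shape)

# least squares selects the parameters (a, b) that best fit the data
a, b = np.polyfit(x, y, deg=1)
g = lambda t: a * t + b

print(f"fitted parameters: a = {a:.3f}, b = {b:.3f}")
print(f"max |g(x_k) - y_k| over the dataset: {np.max(np.abs(g(x) - y)):.3f}")
```

Any other parametric family (polynomials, splines, or neural networks) slots into the same recipe; only the family and the fitting procedure change.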
If we assume that there exists a true underlying function f(x) that describes the relationship between xₖ and yₖ, the problem can be phrased as a function approximation problem: “How can we find a function from our parametric family that is as close to f(x) as possible?”
This leads us into the beautiful albeit very technical field of approximation theory.
A primer on approximation theory
I am (almost) sure that you have encountered the sine function several times already. It is defined in geometric terms, via the unit circle.
Sine is a transcendental function, meaning that you cannot calculate its value with finitely many additions and multiplications.
However, when you punch, say, sin(2.123) into a calculator, you'll get an answer. This is only an approximation, although it is often sufficient for our purposes. In fact, we have
sin(x) ≈ ∑ₖ₌₀ⁿ (−1)ᵏ x²ᵏ⁺¹ / (2k + 1)!,
which is a polynomial, so its value can be computed explicitly. The larger n is, the closer the approximation is to the actual value.
For n = 2, this is the degree-five polynomial x − x³/6 + x⁵/120. It is already a very good approximation, albeit only on the interval [-2, 2].
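Here is a minimal sketch in Python (my own illustration, not from the post) of this truncated series, compared against the value a standard math library returns:

```python
# Minimal sketch: the Taylor polynomial of sine around 0,
#     sin(x) ≈ sum_{k=0}^{n} (-1)^k x^(2k+1) / (2k+1)!
import math

def taylor_sin(x, n=2):
    return sum((-1) ** k * x ** (2 * k + 1) / math.factorial(2 * k + 1)
               for k in range(n + 1))

print(math.sin(2.123))          # the library's (already approximate) value
print(taylor_sin(2.123, n=2))   # the degree-five polynomial from above
print(taylor_sin(2.123, n=10))  # larger n, closer to the true value
```

(Real calculators and libraries use more refined schemes than a plain Taylor series, but the principle is the same: replace sine with a polynomial that is easy to evaluate.)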
The central problem of approximation theory is to provide a mathematical framework for these problems. If you have any function f(x) and a family of functions that are computationally easier to handle, your goal is to find a "simple" function close enough to f. In essence, approximation theory searches for answers to three core questions.
What is "close enough"? (One standard answer is sketched right after this list.)
Which family of functions can (or should) you use to approximate?
From a given approximating family, which function will fit the best?
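Taking the first question as an example: one standard way to make "close enough" precise is the uniform (supremum) distance over an interval, that is, the largest value of |f(x) - g(x)| for x in [a, b]. Below is a minimal sketch (again my own illustration) that estimates it on a dense grid for the degree-five Taylor polynomial of sine:

```python
# Minimal sketch: estimating the uniform (sup) distance
#     dist(f, g) = max over x in [a, b] of |f(x) - g(x)|
# on a dense grid of sample points.
import math
import numpy as np

def uniform_distance(f, g, a, b, num=10_001):
    xs = np.linspace(a, b, num)
    return max(abs(f(x) - g(x)) for x in xs)

p5 = lambda x: x - x**3 / 6 + x**5 / 120  # degree-five Taylor polynomial of sine

print(uniform_distance(math.sin, p5, -2.0, 2.0))  # small: a good fit on [-2, 2]
print(uniform_distance(math.sin, p5, -5.0, 5.0))  # large: the fit breaks down
```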
These questions sound abstract, so let's look into a special case: neural networks.