In computer science and mathematics, constructing a new representation of an old concept can kickstart new fields. One of my favorite examples is the connection between graphs and matrices: representing one as the other has proven to be extremely fruitful for both objects. Translating graph theory to linear algebra and vice versa has been the key to solving numerous hard problems.
What does this have to do with machine learning? If you have ever seen a neural network, you know the answer.
From a purely mathematical standpoint, a neural network is a composition of a sequence of functions, say,

N(x) = Softmax(Linear₁(ReLU(Linear₂(x)))),

whatever those mystical Softmax, Linear, and ReLU functions might be.
On the other hand, take a look at this figure below. Chances are, you have already seen something like this.
This is the computational graph representation of the neural network N. While the expression N(x) = Softmax(Linear₁(ReLU(Linear₂(x)))) describes N on the macro level, the figure above zooms into the micro domain.
Computational graphs are more than just visualization tools; they provide us with a way to manage the complexity of large models like deep neural networks.
To give you a taste of how badly we need this, try calculating the derivative of N(x) by hand. Doing this on paper is daunting enough, let alone turning it into an efficient implementation. Computational graphs solve this problem brilliantly, and in the process, they make machine learning computationally feasible.
Here are the what, why, and how of computational graphs.
What is a computational graph?
Let's go back to square one and take a look at the expression c(a + b). Although this merely seems like a three-variable function defined by

f(a, b, c) = c(a + b),
we can decompose it further, down to its computational atoms. Unraveling the formula, we see that f is the composition of the two functions f₁(a, b) = a + b and f₂(a, b) = ab via

f(a, b, c) = f₂(c, f₁(a, b)).
Even though this is a simple example, the expression f₂(c, f₁(a, b)) is already not the nicest to look at. You can imagine how quickly the complexity slips out of control; just try to rewrite (a + b)(e(c + d) + f) in the same manner. So, we need a more expressive notation.
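(For the record, reusing the addition f₁ and the multiplication f₂ from above, one way the decomposition can go is f₂(f₁(a, b), f₁(f₂(e, f₁(c, d)), f)), which is barely readable even for such a small formula.)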
What if, instead of using an algebraic expression, we build a graph where the components a, b, c, f₁, and f₂ are nodes, and the connections represent the inputs to each component?
It's easier to show than tell, so this is what I am talking about.
In other words, we have a tree graph, where
the leaf nodes (such as a, b, and c) are the input variables,
and the inner nodes (such as f₁ and f₂) are computations, with the child nodes as their inputs.
Because this bears a resemblance to how brain cells communicate, these nodes are called neurons. Hence the term neural network, which is just a fancy name for computational graphs.
Now, the graph above is just a symbolic representation for the function f(a, b, c) = c(a + b). How do we perform the actual computations? We start from the inputs at the leaf nodes, and work our way up step by step.
Here's an illustration using our recurring example c (a + b) with the inputs a = 1, b = 3, and c = 2. Upon initialization, the computation takes two steps to complete.
This process is called the forward pass, as with each step, we propagate the initial values forward in the computational graph, flowing from the leaves towards the root node. Think of this as a function call. Say, if our computational graph encodes a machine learning model, the forward pass results in a prediction.
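If it helps, here is the same forward pass written out in plain Python, mirroring the two steps of the graph:

```python
a, b, c = 1, 3, 2   # the leaf nodes

f1 = a + b          # first step: the inner node computes 1 + 3 = 4
f2 = c * f1         # second step: the root node computes 2 * 4 = 8

print(f2)           # 8
```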
Simple as they are, computational graphs provide a tremendous advantage in practice. So, let's create our own framework!
Computational graphs in practice
Mathematically speaking, computational graphs are nothing special. However, things change significantly when we move to the computational realm. Computational graphs provide such an effective framework that training large neural networks would not be feasible without the clever algorithms they enable. (You might have heard about backpropagation; we'll study and implement it in the next chapter. Safe to say, backpropagation is one of the pillars of deep learning.)
So, it's time to put what we've learned so far into code. Fasten your seatbelts.
First, we learn how to work with computational graphs and familiarize ourselves with their interface. Then, we take what we know and implement our own computational graphs from scratch!
(There are several libraries out there whose purpose is to provide a clear and simple implementation of computational graphs; Andrej Karpathy's micrograd and George Hotz's tinygrad are the ones that come to mind. I have taken significant inspiration from both. Essentially, we'll build our own micrograd library in the following chapters, one step at a time.)
The code is available in the free and open-source `mlfz` package, so feel free to play around! You can find the source at https://github.com/the-palindrome/mlfz. If you want to play around with the code in the post, just clone the `mlfz` repository and follow along in this Jupyter Notebook.
A computational graph is made out of units of computation called neurons. As each neuron represents a scalar, we'll name the underlying class `Scalar`. (In `mlfz`, `Scalar`s can be found in the `mlfz.nn.scalar` module.)
You can think about `Scalar` as a number that keeps track of the computations that yielded it. We instantiate one by either setting its value directly or by applying operations and functions to other `Scalar`s. First, we'll learn how to work with them, then how to implement them.
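For instance, something along these lines; this is a minimal sketch, assuming that `Scalar` can be constructed directly from a plain number, as the rest of the post suggests:

```python
from mlfz.nn.scalar import Scalar

a = Scalar(1)   # created by setting its value directly
b = Scalar(3)
c = a + b       # created by applying an operation to other Scalars
```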
As each neuron acts as a function, the most straightforward way would be to create a `Scalar` class, add a `__call__` method to make it callable, then subclass it for each possible function and override the function call. However, this would require a separate class for each component, quickly leading to an uncontrollable class proliferation.
Thus, we'll go the other way: we supercharge the operations and functions themselves, and build the computational graphs dynamically by applying them. In other words, instead of working with function objects, we make the numeric variables do the work for us.
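To make the idea concrete, here is a stripped-down toy sketch of what "supercharging" an operation means. This is not mlfz's actual implementation, just the general pattern: the overloaded operator both computes the value and records where it came from.

```python
class TinyScalar:
    """A toy illustration of the idea, not the real mlfz Scalar."""

    def __init__(self, value, prevs=None):
        self.value = value
        self.backwards_grad = 0
        self.prevs = prevs or []   # incoming edges: (parent node, local gradient)

    def __mul__(self, other):
        # the product node remembers its inputs along with the local derivatives
        return TinyScalar(
            self.value * other.value,
            prevs=[(self, other.value), (other, self.value)],
        )
```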
What properties does a `Scalar` have, and what can it do? Simple. Each one has
a numeric value,
the backwards gradient (whatever that might be),
and the list of incoming edges.
Let's see these on the simple example of `a * b`.
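A sketch of what that inspection might look like; the concrete numbers are made up, chosen so that b = -2, matching the local gradient we'll look at below:

```python
from mlfz.nn.scalar import Scalar

a = Scalar(3)    # the value of a is an arbitrary choice
b = Scalar(-2)
c = a * b

c.value            # the numeric value: -6
c.backwards_grad   # 0 for now; filled in by the backward pass
c.prevs            # the list of incoming edges, pointing back to a and b
```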
We'll talk about the backwards gradient in detail when we discuss backpropagation and the backward pass. For now, its value is `0`, but it'll eventually contain the derivative of the loss function with respect to the node.
The `Scalar.prevs` attribute contains a list of `Edge` objects, each representing an incoming edge.
In turn, each edge contains a `Scalar` (the node the edge comes from) and the local gradient: the derivative of the current node with respect to that incoming node.
In our case, the local gradient equals ∂c/∂a = b, which is -2 in this example.
Essentially, `Scalar` is a wrapper over a number. That's why functions like Python's built-in `sum` work on them.
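For instance, a sketch along these lines should work, assuming `Scalar` implements both `__add__` and `__radd__` (recall that `sum` starts from the plain integer 0):

```python
from mlfz.nn.scalar import Scalar

scalars = [Scalar(1), Scalar(2), Scalar(3)]

total = sum(scalars)   # sums the values and builds the graph along the way
total.value            # 6
```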
Defining computational graphs
As you can see, the `Scalar` class is overloaded with features: it dynamically builds the underlying computational graph without you having to worry about it. Let's see the already familiar example of c(a + b).
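A sketch of how that might look, reusing the values a = 1, b = 3, c = 2 from before (and, again, assuming direct construction from plain numbers):

```python
from mlfz.nn.scalar import Scalar

a, b, c = Scalar(1), Scalar(3), Scalar(2)

y = c * (a + b)   # evaluating the expression also builds the computational graph
y.value           # 8
```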
The expression `c * (a + b)` describes a directed graph with nodes `a`, `b`, `a + b`, `c`, and `c * (a + b)`. We've seen this before:
In our implementation, the computational graph is fully represented by `Scalar`s and `Edge`s. All you need to do is to define functions using operations and functions as building blocks. `Scalar` objects are compatible with addition, subtraction, multiplication, division, and exponentiation; that is, the operators `+`, `-`, `*`, `/`, and `**`.
Besides those, the `mlfz.nn.scalar.functional` module contains functions like `sin`, `cos`, `exp`, and `log`.
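For example, a function like sin(x) + sin(y) can be put together from these pieces. A rough sketch:

```python
from mlfz.nn.scalar import Scalar
from mlfz.nn.scalar.functional import sin

def f(x, y):
    return sin(x) + sin(y)

z = f(Scalar(1.0), Scalar(2.0))   # evaluates the function and builds its graph
z.value
```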
Just to convince you that it works, here's the plot of our `sin(x) + sin(y)` function.
Linear regression as a computational graph
We're here to do machine learning, so how about implementing a linear regression model? First, the computational graph of ax + b.
Nothing we haven't seen before; in fact, structurally, it is the same as the graph of c (a + b). The only difference is the computations carried out by the nodes.
So, let's define the `linear_regression` function! To avoid ambiguity, we define the parameters a and b of ax + b as a = 3 and b = -1.
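A minimal sketch of what that definition might look like, again assuming `Scalar` is constructed from plain numbers:

```python
from mlfz.nn.scalar import Scalar

a = Scalar(3)
b = Scalar(-1)

def linear_regression(x):
    return a * x + b
```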
We can already glimpse the power of operator overloading: the function `linear_regression` is syntactically the same as the vanilla Python version would be, but this time, it can operate on our powerful `Scalar` objects. (Well, again, soon-to-be powerful.)
With this, we can even fully trace back the computations.
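For instance, one could walk the graph backwards through the `prevs` edges, something like the sketch below. (The `prev` attribute I use for the `Scalar` stored on an `Edge` is a hypothetical name; the real attribute in mlfz may be called differently.)

```python
def trace(node, depth=0):
    """Recursively print a node's value and the nodes it was computed from."""
    print("  " * depth, node.value)
    for edge in node.prevs:
        trace(edge.prev, depth + 1)   # `prev` is a hypothetical attribute name

trace(linear_regression(Scalar(2)))
```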
That's cool, but why would we ever want to do that? It's not clear at this point, but trust me on this: tracing the computations backward is a key tool in calculating derivatives efficiently.
But why would we want to compute the derivative? Because we want to fit the model with gradient descent.
The backward pass
To get a node's derivative with respect to all preceding nodes, we use the famous backpropagation algorithm, implemented via the `Scalar.backward` method.
Let's see it in action!
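Here is a rough sketch of what this looks like with the linear regression graph from above; the input value x = 2 is a made-up choice for the sake of the example.

```python
x = Scalar(2)

y = linear_regression(x)   # forward pass: y = 3 * 2 + (-1) = 5
y.backward()               # backward pass: computes the derivatives of y
```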
Recall the mysterious `Scalar.backwards_grad` attribute? This is where y's derivatives are stored! Upon calling `y.backward()`, the derivatives of y are calculated with respect to all preceding nodes, and then stored in each node's `backwards_grad` attribute.
Let's see ∂y/∂x!
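Continuing the sketch from above, the derivative of y with respect to x ends up on the x node, something like this:

```python
x.backwards_grad   # ∂y/∂x = a = 3
```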
What about ∂y/∂a?
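Similarly, the derivative with respect to the parameter a lands on a (recall that x = 2 was a made-up input in our sketch):

```python
a.backwards_grad   # ∂y/∂a = x = 2
```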
As the partial derivatives for the linear regression y = ax + b are given by ∂y/∂x = a and ∂y/∂a = x, the results indeed seem correct. Yay!
`Scalar` is pretty simple to use; that was more or less all there is to it. In the next part, we'll use computational graphs to train our first machine learning models. See you in the next episode, where we will witness the true power of computational graphs!
P. S. If you are a frequent reader here, you might notice that this post was released about a year ago on The Palindrome. However, since then, I have built an entire neural network library from scratch (which is mlfz, the library used in this post), and I intend to continue the series until tensor networks. I have also remastered the code snippets and illustrations and polished the writing a bit. I hope that you enjoyed it!