The Palindrome

Visualizing the Chain Rule

(because I can never remember the formula)

Tivadar Danka
Dec 22, 2025

The chain rule is one of the most important formulas in machine learning, allowing us to compute the gradient of composite functions, which is exactly what we need to run gradient descent and fit the model to the data.

In a single variable, the chain rule says that

(f ∘ g)′(x) = f′(g(x)) g′(x)

holds for any two univariate differentiable functions f and g. It’s simple, compact, and user-friendly.
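To see the formula in action, here is a minimal numerical sanity check in Python, with f = sin and g(x) = x² as arbitrary illustrative choices of mine (any differentiable pair would do); the chain-rule value agrees with a finite-difference estimate of (f ∘ g)′.

```python
import math

# Illustrative choices: f = sin, g = x^2 (not fixed by anything above)
f = math.sin
g = lambda x: x ** 2

df = math.cos                # f'(x) = cos(x)
dg = lambda x: 2 * x         # g'(x) = 2x

x, eps = 1.3, 1e-6

# Chain rule: (f ∘ g)'(x) = f'(g(x)) · g'(x)
chain = df(g(x)) * dg(x)

# Central finite difference of the composite, for comparison
numeric = (f(g(x + eps)) - f(g(x - eps))) / (2 * eps)

print(chain, numeric)        # the two values agree to roughly 1e-9
```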

However, the math quickly becomes complicated when we introduce multiple variables. And we have to, because there are billions of parameters to tune.

In the general case, we are interested in the composition of the vector-vector function g: ℝⁿ → ℝᵐ with the vector-scalar function f: ℝᵐ → ℝ. For instance, if we have

  • a machine learning model with n parameters x = (x₁, ..., xₙ),

  • a training dataset of m samples d₁, ..., dₘ,

  • predictions g(x) = (g₁(x, d₁), ..., gₘ(x, dₘ)),

  • and a loss function f,

then fitting the model is equivalent to finding the minimum of the composite function h(x) = f(g(x)). (The dependence on the data in g(x) is omitted for simplicity.)
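To make this structure concrete, the sketch below instantiates it with a linear model standing in for g and a squared-error loss standing in for f; the data, targets, and both function choices are purely illustrative assumptions, not anything fixed by the setup above.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 3, 5                        # n parameters, m samples

D = rng.normal(size=(m, n))        # samples d_1, ..., d_m (made-up data)
y = rng.normal(size=m)             # targets, bundled into the samples

def g(x):
    # Vector-vector map ℝⁿ → ℝᵐ: one prediction per sample,
    # here a linear model g_j(x, d_j) = d_j · x (illustrative choice)
    return D @ x

def f(p):
    # Vector-scalar map ℝᵐ → ℝ: squared-error loss against the targets
    return np.sum((p - y) ** 2)

def h(x):
    # The composite function we minimize when fitting the model
    return f(g(x))

print(h(np.zeros(n)))              # loss at the all-zero parameter vector
```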

To compute the gradient, that is,

∇h(x) = (∂₁h(x), ..., ∂ₙh(x))

(where ∂ᵢh denotes the partial derivative of the function h with respect to its i-th variable), the chain rule says that

∂ᵢh(x) = Σⱼ₌₁ᵐ ∂ⱼf(g(x)) ∂ᵢgⱼ(x),    i = 1, ..., n.
This post is about what’s behind this formula.
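Before diving in, a quick empirical check of the formula helps build trust in it. The Python sketch below picks a small g: ℝ² → ℝ³ and f: ℝ³ → ℝ (both arbitrary choices of mine), hard-codes their partial derivatives, and confirms that the chain-rule gradient matches a finite-difference estimate.

```python
import numpy as np

# Arbitrary illustrative choices (mine, not from the post):
# g: ℝ² → ℝ³ and f: ℝ³ → ℝ, both smooth.
def g(x):
    return np.array([x[0] * x[1], np.sin(x[0]), x[1] ** 2])

def f(p):
    return p[0] + p[1] * p[2]

def jac_g(x):
    # Partial derivatives ∂ᵢgⱼ(x), laid out as an (m, n) = (3, 2) matrix
    return np.array([
        [x[1],         x[0]],
        [np.cos(x[0]), 0.0],
        [0.0,          2 * x[1]],
    ])

def grad_f(p):
    # Partial derivatives ∂ⱼf(p)
    return np.array([1.0, p[2], p[1]])

x = np.array([0.7, -1.2])

# Chain rule: ∂ᵢh(x) = Σⱼ ∂ⱼf(g(x)) · ∂ᵢgⱼ(x), i.e. ∇h = J_gᵀ ∇f(g(x))
grad_chain = jac_g(x).T @ grad_f(g(x))

# Finite-difference gradient of h(x) = f(g(x)) for comparison
eps = 1e-6
grad_fd = np.array([
    (f(g(x + eps * e)) - f(g(x - eps * e))) / (2 * eps)
    for e in np.eye(2)
])

print(grad_chain)  # chain-rule gradient
print(grad_fd)     # numerical gradient, matches to ~1e-9
```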
