Chain Rule
In the machine learning world, we will often deal with functions that are more complex than simple polynomial, exponential, or sinusoidal functions. Most of the time, functions will be composite, meaning that one function will be located inside another function (which might also be located within another function). The functions \(\sin{(x^2)}\), \(\ln{(4x\cos{(x^3)}+x)}\), and \(e^{\cos{(3x^2 + 4)}}\) are all composite functions, and being able to calculate the derivatives of such functions is essential for training neural networks.
This material introduces a simple method for computing derivatives of composite functions: the so-called chain rule.
Basics of the Chain Rule
The chain rule can become unruly, notationally, when it is written using the Leibniz notation for the derivative: \(\frac{\mathrm{d}f}{\mathrm{d}x}\). For the moment, let’s adopt a functional notation for the derivative: \(f'(x)\). That is, \(\frac{\mathrm{d}f}{\mathrm{d}x}\) and \(f'(x)\) represent exactly the same function: the derivative of \(f(x)\). For the time being, let’s also assume that all of our functions are single-variable functions.
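Given the composition of the function \(g(x)\) with the function \(f(x)\),
\begin{equation}
g(f(x)),
\end{equation}
the chain rule states that the derivative of the composite function with respect to \(x\) is given by the composition of the function \(g'(x)\) with \(f(x)\), multiplied by \(f'(x)\):
\begin{equation}
\big[g(f(x))\big]' = g'(f(x))\,f'(x).
\end{equation}
Using the \(g \circ f\) notation for function composition, the chain rule says
\begin{equation}
(g \circ f)'(x) = (g' \circ f)(x)\,f'(x).
\end{equation}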
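As a quick numerical sanity check of this formula (a sketch of our own, not part of the original material), the snippet below compares \(g'(f(x))\,f'(x)\) with a central finite-difference approximation for the composite function \(\sin{(x^2)}\). The helper names chain_rule_derivative and central_difference are hypothetical and chosen for illustration.

```python
import math

def chain_rule_derivative(g_prime, f, f_prime, x):
    """Return (g o f)'(x) = g'(f(x)) * f'(x)."""
    return g_prime(f(x)) * f_prime(x)

def central_difference(func, x, h=1e-6):
    """Approximate func'(x) with a symmetric difference quotient."""
    return (func(x + h) - func(x - h)) / (2 * h)

# The composite function sin(x**2): outer g = sin, inner f(x) = x**2
def f(x):
    return x ** 2

def f_prime(x):
    return 2 * x

g = math.sin
g_prime = math.cos

x0 = 1.3
analytic = chain_rule_derivative(g_prime, f, f_prime, x0)
numeric = central_difference(lambda t: g(f(t)), x0)
print(f"chain rule: {analytic:.8f}   finite difference: {numeric:.8f}")
```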
Example Calculation Using the Chain Rule
Let’s jump to an example immediately to make sure that we are not confused by this notation. Consider the following functions:
\begin{align} f(x) &= 3x + 1\\ g(x) &= x^2 - 2 \end{align}
The derivatives of \(f(x)\) and \(g(x)\) are quite simple: \begin{align} f'(x) &= 3\\ g'(x) &= 2x\\ \end{align}
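According to the chain rule, this is all we need to compute the derivative of \((g\circ f)(x)\). Recognizing that \((g'\circ f)(x) = 2f(x)\), we can write the derivative of \((g\circ f)(x)\) as
\begin{equation}
(g \circ f)'(x) = (g' \circ f)(x)\,f'(x) = 2f(x)\,f'(x).
\end{equation}
Plugging in for \(f(x)\) and \(f'(x)\), we obtain
\begin{equation}
(g \circ f)'(x) = 2(3x + 1)\cdot 3 = 18x + 6.
\end{equation}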
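To check this result without expanding anything, the following sketch evaluates \((g' \circ f)(x)\,f'(x)\) at a few points and compares it against a central finite difference of \(g(f(x))\). It is our own illustration, and the function names are hypothetical. Because \(g(f(x))\) is a quadratic, the central difference agrees with the exact derivative up to floating-point round-off.

```python
def f(x):
    return 3 * x + 1

def f_prime(x):
    return 3

def g(x):
    return x ** 2 - 2

def g_prime(x):
    return 2 * x

def composite_derivative(x):
    # Chain rule: (g' o f)(x) * f'(x)
    return g_prime(f(x)) * f_prime(x)

def central_difference(func, x, h=1e-5):
    """Approximate func'(x) with a symmetric difference quotient."""
    return (func(x + h) - func(x - h)) / (2 * h)

for x in [-2.0, 0.0, 0.5, 3.0]:
    approx = central_difference(lambda t: g(f(t)), x)
    print(f"x = {x:5.2f}   chain rule: {composite_derivative(x):8.4f}   numeric: {approx:8.4f}")
```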
As an exercise, write \((g \circ f)(x)\) out in full — as \((g \circ f)(x) = (3x + 1)^2 - 2\), expanding the squared term — and take its derivative directly. Verify that the result you obtain agrees with the equation for \((g \circ f)'(x)\) that we arrived at by using the chain rule. Review this example carefully, and be sure to have a clear understanding of the symbolic form of the chain rule.
Representing the Chain Rule Using Leibniz Notation
We will ultimately need to make use of the chain rule generalized to multivariable functions. For this, Leibniz notation is extremely valuable. Recall that we write the partial derivative of \(f(x,y)\) with respect to \(x\) as \(\frac{\partial f}{\partial x}\). Let’s translate the chain rule into Leibniz notation:
\begin{equation}
\frac{\mathrm{d}(g \circ f)}{\mathrm{d}x} = \frac{\mathrm{d}g}{\mathrm{d}f}\Bigr|_{f=f(x)}\frac{\mathrm{d}f}{\mathrm{d}x}
\end{equation}
Here, \(g\) takes the output of another function, \(f(x)\), as its input. The vertical line indicates that the derivative \(\frac{\mathrm{d}g}{\mathrm{d}f}\) is to be evaluated using the value of \(f(x)\) as its input variable. We will always evaluate intermediate derivatives within the chain rule in this way, so we can omit the vertical line as long as we keep this convention in mind. Thus the chain rule, written using Leibniz notation, is
\begin{equation}
\frac{\mathrm{d}(g \circ f)}{\mathrm{d}x} = \frac{\mathrm{d}g}{\mathrm{d}f}\frac{\mathrm{d}f}{\mathrm{d}x}.
\end{equation}
This is the notation that we will use moving forward, especially as we begin to work with partial derivatives of multivariable functions. This simple chain rule is also sufficient for handling an arbitrarily long sequence of compositions.
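The same idea extends mechanically to a chain of any length. The sketch below uses a hypothetical helper, derivative_of_composition, written for illustration. It first evaluates the composition from the innermost function outward and records the input each function received. It then multiplies the individual derivatives, each evaluated at its own input. The example is \(e^{\cos{(3x^2 + 4)}}\), one of the composite functions mentioned at the start of this section.

```python
import math

def derivative_of_composition(chain, x):
    """Differentiate f1 o f2 o ... o fn at x via the chain rule.

    `chain` lists (function, derivative) pairs from the outermost function f1
    to the innermost function fn. Returns the pair (value, derivative) at x.
    """
    # Evaluate from the innermost function outward, recording each function's input
    inputs = []
    value = x
    for func, _ in reversed(chain):
        inputs.append(value)
        value = func(value)
    inputs.reverse()  # inputs[j] is now the argument that chain[j] received

    # Multiply the derivatives, each evaluated at its own input
    derivative = 1.0
    for (_, func_prime), arg in zip(chain, inputs):
        derivative *= func_prime(arg)
    return value, derivative

# e^{cos(3x^2 + 4)}: f1 = exp, f2 = cos, f3(x) = 3x^2 + 4
chain = [
    (math.exp, math.exp),
    (math.cos, lambda u: -math.sin(u)),
    (lambda t: 3 * t ** 2 + 4, lambda t: 6 * t),
]
value, derivative = derivative_of_composition(chain, 0.7)
print(value, derivative)
```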
Reading Comprehension: Proof of Chain Rule With Multiple Composite Functions
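Use the equation \(\frac{\mathrm{d}(g \circ f)}{\mathrm{d}x} = \frac{\mathrm{d}g}{\mathrm{d}f}\frac{\mathrm{d}f}{\mathrm{d}x}\) to prove that
\begin{equation}
\frac{\mathrm{d}(f_1 \circ f_2 \circ \cdots \circ f_n)}{\mathrm{d}x} = \frac{\mathrm{d}f_1}{\mathrm{d}f_2}\frac{\mathrm{d}f_2}{\mathrm{d}f_3}\cdots\frac{\mathrm{d}f_{n-1}}{\mathrm{d}f_n}\frac{\mathrm{d}f_n}{\mathrm{d}x},
\end{equation}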
where \(\frac{\mathrm{d}f_j}{\mathrm{d}f_{j+1}}\) is understood to be evaluated at \((f_{j+1} \circ \cdots \circ f_n)(x)\).
Hint: Consider one composition at a time. In other words, what is \(\frac{\mathrm{d}(f_1\circ g)}{\mathrm{d}x}\), where \(g=f_2\circ\cdots\circ f_n\)?
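One final note to help clarify the vertical-bar notation used above. If we wanted to compute the derivative of \((g \circ f)\) evaluated at, say, \(x = 2\), we would denote this as
\begin{equation}
\frac{\mathrm{d}(g \circ f)}{\mathrm{d}x}\Bigr|_{x=2},
\end{equation}
which, of course, is the same as writing
\begin{equation}
(g \circ f)'(2).
\end{equation}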
To be clear, \((g \circ f)'(2)\) and \(\frac{\mathrm{d}(g \circ f)}{\mathrm{d}x}\Bigr|_{x=2}\) both mean: take the derivative of \((g \circ f)(x)\) and evaluate the resulting function at \(x = 2\). It doesn’t make sense to take the derivative of \((g \circ f)(2)\), as this is simply a number.
Reading Comprehension: Chain Rule With a Single Variable Function
Calculate the derivative with respect to \(x\) of the function
First, do this using the chain rule. Then do it by expanding out the function and using just the power rule. Confirm that both derivatives are equivalent.
The Chain Rule for Multivariable Functions
Extending the chain rule to a single-variable function composed with a multivariable function is straightforward: the derivative of the inner function simply becomes a partial derivative. Take the single-variable function \(g(x)\) and the multivariable function \(f(x,y)\). Then, for \(g(f(x,y))\), \begin{align} \frac{\mathrm{d}g}{\mathrm{d}x} &= \frac{\mathrm{d}g}{\mathrm{d}f}\frac{\partial f}{\partial x} \\ \frac{\mathrm{d}g}{\mathrm{d}y} &= \frac{\mathrm{d}g}{\mathrm{d}f}\frac{\partial f}{\partial y} \end{align}
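The following is a small numerical illustration of these two formulas. For illustration only, it assumes \(g(u) = \sin u\) and \(f(x, y) = xy^2\), so that \(\frac{\partial f}{\partial x} = y^2\) and \(\frac{\partial f}{\partial y} = 2xy\). Each chain-rule result is compared with a central difference that nudges only one of the two variables.

```python
import math

# Inner multivariable function f(x, y) = x * y**2 and its partial derivatives
def f(x, y):
    return x * y ** 2

def df_dx(x, y):
    return y ** 2

def df_dy(x, y):
    return 2 * x * y

# Outer single-variable function g(u) = sin(u) and its derivative
def g(u):
    return math.sin(u)

def dg_df(u):
    return math.cos(u)

# Chain rule: dg/dx = (dg/df)(∂f/∂x) and dg/dy = (dg/df)(∂f/∂y)
def dg_dx(x, y):
    return dg_df(f(x, y)) * df_dx(x, y)

def dg_dy(x, y):
    return dg_df(f(x, y)) * df_dy(x, y)

# Compare against central differences that vary one variable at a time
h = 1e-6
x0, y0 = 0.5, 1.2
numeric_dx = (g(f(x0 + h, y0)) - g(f(x0 - h, y0))) / (2 * h)
numeric_dy = (g(f(x0, y0 + h)) - g(f(x0, y0 - h))) / (2 * h)
print(dg_dx(x0, y0), numeric_dx)
print(dg_dy(x0, y0), numeric_dy)
```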
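You will also encounter more complicated instances, in which \(g\) itself depends on multiple functions of the independent variables: \(g(x, y) = g(p(x, y),\, q(x, y))\). The following result is very important. Here, you simply accumulate (i.e., sum) the derivatives that are contributed by \(p\) and \(q\), respectively:
\begin{align} \frac{\mathrm{d} g}{\mathrm{d} x} &= \frac{\partial g}{\partial p}\frac{\partial p}{\partial x} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial x} \\ \frac{\mathrm{d} g}{\mathrm{d} y} &= \frac{\partial g}{\partial p}\frac{\partial p}{\partial y} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial y} \end{align}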
Again, this can be generalized to accommodate an arbitrary number of dependent variables. So, for the function \(g(f_1(x, y), f_2(x, y), ..., f_n(x, y))\), \begin{align} \frac{\mathrm{d} g}{\mathrm{d} x} &= \frac{\partial g}{\partial f_1}\frac{\partial f_1}{\partial x} + \frac{\partial g}{\partial f_2}\frac{\partial f_2}{\partial x} + ... + \frac{\partial g}{\partial f_n}\frac{\partial f_n}{\partial x} \\ \frac{\mathrm{d} g}{\mathrm{d} y} &= \frac{\partial g}{\partial f_1}\frac{\partial f_1}{\partial y} + \frac{\partial g}{\partial f_2}\frac{\partial f_2}{\partial y} + ... + \frac{\partial g}{\partial f_n}\frac{\partial f_n}{\partial y} \\ \end{align}
This should make sense once dissected. We want to describe how varying \(x\) by a small amount affects \(g\). Thus we need to know how varying \(x\) affects \(f_1\) (through \(\frac{\partial f_1}{\partial x}\)), and multiply it by how varying \(f_1\) affects \(g\) (through \(\frac{\partial g}{\partial f_1}\)). So \(\frac{\partial g}{\partial f_1}\frac{\partial f_1}{\partial x}\) describes how varying \(x\) affects \(g\) via \(f_1\). Repeat this for \(f_2,\dots,\,f_n\), and sum all of these contributions to arrive at how varying \(x\) affects \(g\) in total: \(\frac{\mathrm{d} g}{\mathrm{d} x}\).
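This "sum over paths" reading translates directly into code. In the sketch below, total_derivative is a hypothetical helper that takes the values of \(\frac{\partial g}{\partial f_i}\) and \(\frac{\partial f_i}{\partial x}\) at a point and adds up their products. The example functions \(g = f_1 f_2 + f_3\), \(f_1 = x + y\), \(f_2 = xy\), and \(f_3 = x^2\) are our own. Expanding \(g = x^2y + xy^2 + x^2\) by hand gives \(\frac{\mathrm{d}g}{\mathrm{d}x} = 2xy + y^2 + 2x\) and \(\frac{\mathrm{d}g}{\mathrm{d}y} = x^2 + 2xy\), which the code should match at \((2, 3)\).

```python
def total_derivative(dg_dfs, dfs_dvar):
    """Sum the contributions of every path var -> f_i -> g.

    dg_dfs[i] holds ∂g/∂f_i and dfs_dvar[i] holds ∂f_i/∂var, all evaluated
    at the same point.
    """
    if len(dg_dfs) != len(dfs_dvar):
        raise ValueError("need one ∂g/∂f_i for every ∂f_i/∂var")
    return sum(a * b for a, b in zip(dg_dfs, dfs_dvar))

# Example: g(f1, f2, f3) = f1 * f2 + f3 with
# f1 = x + y, f2 = x * y, f3 = x**2, evaluated at (x, y) = (2, 3)
x, y = 2.0, 3.0
f1, f2, f3 = x + y, x * y, x ** 2
dg_df = [f2, f1, 1.0]                             # ∂g/∂f1, ∂g/∂f2, ∂g/∂f3
dg_dx = total_derivative(dg_df, [1.0, y, 2 * x])  # ∂f_i/∂x
dg_dy = total_derivative(dg_df, [1.0, x, 0.0])    # ∂f_i/∂y
print(dg_dx, dg_dy)
```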
A Simple Example
Given the following functions, we will calculate \(\frac{\mathrm{d} g}{\mathrm{d} x}\) and \(\frac{\mathrm{d} g}{\mathrm{d} y}\) at the point \((x=3, y=1)\). Take \(g(p(x,y), q(x, y))\) to be given by
\begin{align} g(p, q) &= p^2 - q^3\\ p(x, y) &= x^2 y\\ q(x, y) &= 2x + y \end{align}
According to the chain rule provided above, the derivatives needed to compute \(\frac{\mathrm{d} g}{\mathrm{d} x}\) and \(\frac{\mathrm{d} g}{\mathrm{d} y}\) are simply \begin{align} \frac{\partial g}{\partial p}\bigg|_{x=3, y=1} &= 2p(3, 1) = 2\cdot (1 \cdot 3^2) = 18\\ \frac{\partial g}{\partial q}\bigg|_{x=3, y=1} &= -3q(3, 1)^2 = -3\cdot (2 \cdot 3 + 1)^2 = -3\cdot (49) = -147 \\ \frac{\partial p}{\partial x}\bigg|_{x=3, y=1} &= 2yx\big|_{x=3, y=1} = 2 (1 \cdot 3) = 6\\ \frac{\partial p}{\partial y}\bigg|_{x=3, y=1} &= x^2\big|_{x=3, y=1} = 3^2 = 9\\ \frac{\partial q}{\partial x}\bigg|_{x=3, y=1} &= 2 \\ \frac{\partial q}{\partial y}\bigg|_{x=3, y=1} &= 1 \end{align}
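We can simply plug these values into the expression for the chain rule for a function of multiple dependent variables, and we will have computed the derivatives of \(g\) with respect to \(x\) and \(y\) at the given point:
\begin{align} \frac{\mathrm{d} g}{\mathrm{d} x}\bigg|_{x=3, y=1} &= \frac{\partial g}{\partial p}\frac{\partial p}{\partial x} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial x} = 18\cdot 6 + (-147)\cdot 2 = -186\\ \frac{\mathrm{d} g}{\mathrm{d} y}\bigg|_{x=3, y=1} &= \frac{\partial g}{\partial p}\frac{\partial p}{\partial y} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial y} = 18\cdot 9 + (-147)\cdot 1 = 15 \end{align}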
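The same arithmetic can be reproduced in plain Python. This sketch checks the numbers without any autodifferentiation library. It computes each local partial derivative at \((3, 1)\), combines them with the multivariable chain rule, and compares the result with central differences of \(g\) itself.

```python
# The functions from this example
def p(x, y):
    return y * x ** 2

def q(x, y):
    return 2 * x + y

def g(x, y):
    return p(x, y) ** 2 - q(x, y) ** 3

x0, y0 = 3.0, 1.0

# Local partial derivatives, evaluated at (3, 1)
dg_dp = 2 * p(x0, y0)                # 18
dg_dq = -3 * q(x0, y0) ** 2          # -147
dp_dx, dp_dy = 2 * y0 * x0, x0 ** 2  # 6, 9
dq_dx, dq_dy = 2.0, 1.0

# Multivariable chain rule: sum the contributions through p and q
dg_dx = dg_dp * dp_dx + dg_dq * dq_dx
dg_dy = dg_dp * dp_dy + dg_dq * dq_dy
print(dg_dx, dg_dy)  # -186.0 15.0

# Independent check with central differences on g itself
h = 1e-4
numeric_dx = (g(x0 + h, y0) - g(x0 - h, y0)) / (2 * h)
numeric_dy = (g(x0, y0 + h) - g(x0, y0 - h)) / (2 * h)
print(numeric_dx, numeric_dy)
```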
Autodifferentiation and the Chain Rule
Autodifferentiation libraries, like MyGrad, use the chain rule to compute derivatives of composite functions. The following code shows that MyGrad reproduces exactly the derivative values computed by hand above.
# Using MyGrad to evaluate the partial derivatives
# of a composite multivariable function
>>> import mygrad as mg
# Initialize x and y as float tensors: recent MyGrad versions treat
# integer-valued tensors as constants, which receive no gradient
>>> x = mg.tensor(3.0)
>>> y = mg.tensor(1.0)
>>> p = y * x ** 2
>>> q = 2 * x + y
>>> g = p ** 2 - q ** 3
# Compute the derivatives of g with respect to all
# of the tensors that it depends on
>>> g.backward()
>>> p.grad  # stores ∂g/∂p @ x=3, y=1
array(18.)
>>> q.grad  # stores ∂g/∂q @ x=3, y=1
array(-147.)
>>> x.grad  # stores dg/dx @ x=3, y=1
array(-186.)
>>> y.grad  # stores dg/dy @ x=3, y=1
array(15.)
Reading Comprehension: Chain Rule With a Multivariable Function
For \(g(p(x,y), q(x,y))\), where
\begin{align} g(p, q) &= p^2 q\\ p(x, y) &= 4y - x\\ q(x, y) &= 3xy - 4y \end{align}
calculate \(\frac{\mathrm{d} g}{\mathrm{d} y}\big|_{x=5, y=4}\).
Reading Comprehension: Chain Rule With a Multivariable Function With MyGrad
Calculate the same partial derivative from the previous question (Chain Rule With a Multivariable Function), but this time, compute it using MyGrad. Verify that this gives the same result as doing the math by hand.
Reading Comprehension Exercise Solutions
Proof of Chain Rule With Multiple Composite Functions: Solution
We are given that \(\frac{\mathrm{d}(g \circ f)}{\mathrm{d}x} = \frac{\mathrm{d}g}{\mathrm{d}f}\frac{\mathrm{d}f}{\mathrm{d}x}\).
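Performing one iteration of the chain rule on the given function, with \(g = f_2 \circ \cdots \circ f_n\) as suggested in the hint, we get that
\begin{equation}
\frac{\mathrm{d}(f_1 \circ f_2 \circ \cdots \circ f_n)}{\mathrm{d}x} = \frac{\mathrm{d}f_1}{\mathrm{d}f_2}\frac{\mathrm{d}(f_2 \circ \cdots \circ f_n)}{\mathrm{d}x}.
\end{equation}
Applying the chain rule again, this time to \(\frac{\mathrm{d}(f_2 \circ \cdots \circ f_n)}{\mathrm{d}x}\), we find that
\begin{equation}
\frac{\mathrm{d}(f_2 \circ f_3 \circ \cdots \circ f_n)}{\mathrm{d}x} = \frac{\mathrm{d}f_2}{\mathrm{d}f_3}\frac{\mathrm{d}(f_3 \circ \cdots \circ f_n)}{\mathrm{d}x}.
\end{equation}
Substituting this into the previous equation, we get that
\begin{equation}
\frac{\mathrm{d}(f_1 \circ f_2 \circ \cdots \circ f_n)}{\mathrm{d}x} = \frac{\mathrm{d}f_1}{\mathrm{d}f_2}\frac{\mathrm{d}f_2}{\mathrm{d}f_3}\frac{\mathrm{d}(f_3 \circ \cdots \circ f_n)}{\mathrm{d}x}.
\end{equation}
We can keep repeating this process until only \(\frac{\mathrm{d}f_n}{\mathrm{d}x}\) remains, and we will find that
\begin{equation}
\frac{\mathrm{d}(f_1 \circ f_2 \circ \cdots \circ f_n)}{\mathrm{d}x} = \frac{\mathrm{d}f_1}{\mathrm{d}f_2}\frac{\mathrm{d}f_2}{\mathrm{d}f_3}\cdots\frac{\mathrm{d}f_{n-1}}{\mathrm{d}f_n}\frac{\mathrm{d}f_n}{\mathrm{d}x}.
\end{equation}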
Chain Rule With a Multivariable Function: Solution
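The multivariable chain rule states that
\begin{equation}
\frac{\mathrm{d} g}{\mathrm{d} y} = \frac{\partial g}{\partial p}\frac{\partial p}{\partial y} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial y}.
\end{equation}
We can now find the partial derivatives of \(g\) with respect to \(p\) and \(q\) and evaluate them at \((x,y)=(5,4)\), where \(p(5,4) = 11\) and \(q(5,4) = 44\):
\begin{align} \frac{\partial g}{\partial p}\bigg|_{x=5, y=4} &= 2p(5,4)\,q(5,4) = 2\cdot 11\cdot 44 = 968\\ \frac{\partial g}{\partial q}\bigg|_{x=5, y=4} &= p(5,4)^2 = 11^2 = 121 \end{align}
Finally, the partial derivatives of \(p\) and \(q\) with respect to \(y\) can be found and evaluated as
\begin{align} \frac{\partial p}{\partial y}\bigg|_{x=5, y=4} &= 4\\ \frac{\partial q}{\partial y}\bigg|_{x=5, y=4} &= (3x - 4)\big|_{x=5, y=4} = 11 \end{align}
Therefore
\begin{equation}
\frac{\mathrm{d} g}{\mathrm{d} y}\bigg|_{x=5, y=4} = 968\cdot 4 + 121\cdot 11 = 3872 + 1331 = 5203.
\end{equation}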
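As an independent check on this hand calculation (a sketch of our own, not part of the original solution), a central finite difference of \(g\) with respect to \(y\) at \((5, 4)\) should land very close to the chain-rule value. The code rebuilds each chain-rule factor from the definitions of \(p\) and \(q\) rather than hard-coding the results.

```python
def g(x, y):
    """g(p(x, y), q(x, y)) = p**2 * q with p = 4y - x and q = 3xy - 4y."""
    p = 4 * y - x
    q = 3 * x * y - 4 * y
    return p ** 2 * q

x0, y0 = 5.0, 4.0

# Chain-rule factors, evaluated at (5, 4)
p0 = 4 * y0 - x0            # p(5, 4)
q0 = 3 * x0 * y0 - 4 * y0   # q(5, 4)
dg_dp = 2 * p0 * q0         # ∂g/∂p = 2pq
dg_dq = p0 ** 2             # ∂g/∂q = p^2
dp_dy = 4.0                 # ∂p/∂y
dq_dy = 3 * x0 - 4          # ∂q/∂y = 3x - 4
chain_rule_value = dg_dp * dp_dy + dg_dq * dq_dy

# Central-difference approximation of dg/dy at (5, 4)
h = 1e-4
numeric_value = (g(x0, y0 + h) - g(x0, y0 - h)) / (2 * h)
print(chain_rule_value, numeric_value)
```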
Chain Rule With a Multivariable Function With MyGrad: Solution
# Using MyGrad to evaluate the partial derivatives of a multivariable function
>>> import mygrad as mg
# Initialize x and y as float tensors: recent MyGrad versions treat
# integer-valued tensors as constants, which receive no gradient
>>> x = mg.tensor(5.0)
>>> y = mg.tensor(4.0)
>>> p = 4 * y - x
>>> q = 3 * x * y - 4 * y
>>> g = p ** 2 * q
# Compute the derivatives of g with respect to all
# of the tensors that it depends on
>>> g.backward()
>>> p.grad  # stores ∂g/∂p @ x=5, y=4
array(968.)
>>> q.grad  # stores ∂g/∂q @ x=5, y=4
array(121.)
>>> x.grad  # stores dg/dx @ x=5, y=4
array(484.)
>>> y.grad  # stores dg/dy @ x=5, y=4
array(5203.)