Chain Rule
In the machine learning world, we will often deal with functions that are more complex than simple polynomial, exponential, or sinusoidal functions. Most of the time, functions will be composite, meaning that one function will be located inside another function (which might also be located within another function). The functions
This material introduces a simple method for computing derivatives of composite functions: the so-called chain rule.
Basics of the Chain Rule
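The chain rule can become unruly from a notational point of view when using the Leibniz notation for the derivative, $\frac{df}{dx}$.
Given the composition of the function $f$ with the function $g$, $(f \circ g)(x) = f(g(x))$,
the chain rule states that the derivative of the composite function with respect to $x$ is the derivative of $f$ evaluated at $g(x)$, multiplied by the derivative of $g$ evaluated at $x$.
Using the prime notation for derivatives, this reads:
$$(f \circ g)'(x) = f'(g(x)) \, g'(x)$$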
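The chain rule is easy to sanity-check numerically. The sketch below is not part of the original text: it picks $f(u) = \sin(u)$ and $g(x) = x^2$ purely for illustration, and compares the chain-rule value $f'(g(x)) \, g'(x)$ against a central finite-difference estimate of the derivative of $f(g(x))$. The helper name `central_difference`, the step size `h`, and the evaluation point `x0` are all assumptions.

```python
import math


def f(u):
    return math.sin(u)


def df(u):
    # Derivative of the outer function, f'(u) = cos(u)
    return math.cos(u)


def g(x):
    return x ** 2


def dg(x):
    # Derivative of the inner function, g'(x) = 2x
    return 2 * x


def chain_rule_derivative(x):
    # (f ∘ g)'(x) = f'(g(x)) * g'(x)
    return df(g(x)) * dg(x)


def central_difference(func, x, h=1e-6):
    # Symmetric finite-difference estimate of func'(x)
    return (func(x + h) - func(x - h)) / (2 * h)


x0 = 1.5
analytic = chain_rule_derivative(x0)
numeric = central_difference(lambda x: f(g(x)), x0)
print(analytic, numeric)
```

The two printed numbers should agree to several decimal places; any remaining difference comes from the finite-difference approximation, not from the chain rule.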
Example Calculation Using the Chain Rule
Let’s jump to an example immediately to make sure that we are not confused by this notation. Consider the following functions:
The derivatives of
According to the chain rule, this is all we need to compute the derivative of
Plugging in for
As an exercise, write
Representing the Chain Rule Using Leibniz Notation
We will ultimately need to make use of the chain rule generalized to multivariable functions. For this, Leibniz notation is extremely valuable. Recall that we write the partial derivative of
Here,
This is the notation that we will use moving forward, especially as we begin to work with partial derivatives of multivariable functions. This simple chain rule is also sufficient for generalizing to an arbitrarily long sequence of compositions.
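For reference, here is one common way of writing the single-variable chain rule in Leibniz notation. The symbols are assumptions made for this sketch and may differ from those used elsewhere in this material: $f$ is the outer function, $g$ is the inner function, and $x_0$ is the point of evaluation. The vertical bar marks where each derivative is evaluated.

```latex
\frac{d(f \circ g)}{dx}\bigg|_{x = x_0}
  = \frac{df}{dg}\bigg|_{g = g(x_0)} \cdot \frac{dg}{dx}\bigg|_{x = x_0}
```

Note that the outer derivative is evaluated at the output of the inner function, $g(x_0)$, not at $x_0$ itself.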
Reading Comprehension: Proof of Chain Rule With Multiple Composite Functions
Use the equation
where
Hint: Consider one composition at a time. In other words, what is
One final note to help clarify the vertical-bar notation used above. If we wanted to compute the derivative of
which, of course, is the same as writing
To be clear,
Reading Comprehension: Chain Rule With a Single Variable Function
Calculate the derivative with respect to
First, do this using the chain rule. Then do it by expanding out the function and using just the power rule. Confirm that both derivatives are equivalent.
The Chain Rule for Multivariable Functions
The case of composing a single-variable function with a multivariable one is quite simple for extending the chain rule with partial derivatives. Take the single-variable function
You will also encounter more complicated instances, in which
Again, this can be generalized to accommodate an arbitrary number of dependent variables. So, for the function
This should make sense once dissected — we want to describe how varying
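As a compact reference, the multivariable chain rule can be sketched as follows. The names are assumptions made for this sketch: $f(p_1, \dots, p_n)$ is the outer function, and each argument $p_i$ is itself a function of $x$ (and possibly of other variables). The derivative sums one contribution for every path by which $x$ influences $f$.

```latex
\frac{\partial f}{\partial x}
  = \sum_{i=1}^{n} \frac{\partial f}{\partial p_i} \, \frac{\partial p_i}{\partial x}
  = \frac{\partial f}{\partial p_1}\frac{\partial p_1}{\partial x}
  + \cdots
  + \frac{\partial f}{\partial p_n}\frac{\partial p_n}{\partial x}
```

With $n = 1$ this reduces to the single-variable chain rule.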
A Simple Example
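Given the following functions, we will calculate $\frac{\partial g}{\partial x}$ and $\frac{\partial g}{\partial y}$ at $x = 3$, $y = 1$:
$$g(p, q) = p^2 - q^3, \quad p(x, y) = x^2 y, \quad q(x, y) = 2x + y$$
According to the chain rule provided above, the derivatives needed to compute these, evaluated at $x = 3$, $y = 1$ (where $p = 9$ and $q = 7$), are:
$$\frac{\partial g}{\partial p} = 2p = 18, \quad \frac{\partial g}{\partial q} = -3q^2 = -147$$
$$\frac{\partial p}{\partial x} = 2xy = 6, \quad \frac{\partial p}{\partial y} = x^2 = 9, \quad \frac{\partial q}{\partial x} = 2, \quad \frac{\partial q}{\partial y} = 1$$
We can simply plug these values into the expression for the chain rule for a function of multiple dependent variables, and we will have computed the derivatives of $g$ with respect to $x$ and $y$:
$$\frac{\partial g}{\partial x} = \frac{\partial g}{\partial p}\frac{\partial p}{\partial x} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial x} = (18)(6) + (-147)(2) = -186$$
$$\frac{\partial g}{\partial y} = \frac{\partial g}{\partial p}\frac{\partial p}{\partial y} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial y} = (18)(9) + (-147)(1) = 15$$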
Autodifferentiation and the Chain Rule
Autodifferentiation libraries, like MyGrad, naturally use the chain rule to compute derivatives of composite functions. See that it reproduces the exact same values for derivatives as indicated above.
# Using MyGrad to evaluate the partial derivatives
# of a composite multivariable function
>>> import mygrad as mg
# Initializes x and y as floating-point tensors
>>> x = mg.tensor(3.0)
>>> y = mg.tensor(1.0)
>>> p = y * x ** 2
>>> q = 2 * x + y
>>> g = p ** 2 - q ** 3
# Computes the derivatives of g with respect to all
# variables that it depends on
>>> g.backward()
>>> p.grad # stores ∂g/∂p @ x=3, y=1
array(18.)
>>> q.grad # stores ∂g/∂q @ x=3, y=1
array(-147.)
>>> x.grad # stores ∂g/∂x @ x=3, y=1
array(-186.)
>>> y.grad # stores ∂g/∂y @ x=3, y=1
array(15.)
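As an independent cross-check that does not rely on automatic differentiation, the same partial derivatives can be estimated with central finite differences in plain Python. This is a sketch, not part of the original material: the functions `p`, `q`, and `g` mirror the expressions in the MyGrad session above, while the helper `partial` and its step size `h` are hypothetical names introduced here.

```python
def p(x, y):
    return y * x ** 2


def q(x, y):
    return 2 * x + y


def g(x, y):
    # Composite function: g depends on x and y only through p and q
    return p(x, y) ** 2 - q(x, y) ** 3


def partial(func, point, index, h=1e-5):
    """Central-difference estimate of the partial derivative of `func`
    with respect to the argument at position `index`, at `point`."""
    forward = list(point)
    backward = list(point)
    forward[index] += h
    backward[index] -= h
    return (func(*forward) - func(*backward)) / (2 * h)


point = (3.0, 1.0)
dg_dx = partial(g, point, 0)  # estimate of ∂g/∂x at x=3, y=1
dg_dy = partial(g, point, 1)  # estimate of ∂g/∂y at x=3, y=1
print(dg_dx, dg_dy)
```

The estimates should match the values stored in `x.grad` and `y.grad` up to finite-difference error.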
Reading Comprehension: Chain Rule With a Multivariable Function
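For the functions $g(p, q) = p^2 q$, $p(x, y) = 4y - x$, and $q(x, y) = 3xy - 4y$, evaluated at $x = 5$, $y = 4$,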
calculate
Reading Comprehension: Chain Rule With a Multivariable Function With MyGrad
Calculate the same partial derivative from the previous question (Chain Rule With a Multivariable Function), but this time, compute it using MyGrad. Verify that this gives the same result as doing the math by hand.
Reading Comprehension Exercise Solutions
Proof of Chain Rule With Multiple Composite Functions: Solution
We are given that
Performing one iteration of the chain rule on the given function, we get that
Performing the chain rule on
Substituting this into the previous equation, we get that
We can keep repeating this process and we will find that
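The repeated application of the chain rule described in this solution can be spot-checked numerically. The sketch below is an illustration under stated assumptions: the three functions (exponential, sine, and a cubic) and the names `layers`, `compose`, and `chain_rule_derivative` are chosen here and do not come from the original text. It multiplies each function's derivative, evaluated at that function's own input, and compares the product against a central finite-difference estimate.

```python
import math

# Each entry pairs a function with its derivative. The composition is
# layers[0](layers[1](layers[2](x))), i.e. exp(sin(x ** 3)).
layers = [
    (math.exp, math.exp),
    (math.sin, math.cos),
    (lambda x: x ** 3, lambda x: 3 * x ** 2),
]


def compose(x):
    # Apply the innermost function first
    for func, _ in reversed(layers):
        x = func(x)
    return x


def chain_rule_derivative(x):
    # Record the input seen by each function, innermost first
    inputs = []
    for func, _ in reversed(layers):
        inputs.append(x)
        x = func(x)
    # Multiply each derivative evaluated at its own function's input
    product = 1.0
    for (_, dfunc), value in zip(reversed(layers), inputs):
        product *= dfunc(value)
    return product


x0 = 0.7
h = 1e-6
numeric = (compose(x0 + h) - compose(x0 - h)) / (2 * h)
analytic = chain_rule_derivative(x0)
print(analytic, numeric)
```

Adding or removing entries in `layers` changes the number of compositions without touching the rest of the code, which reflects the one-composition-at-a-time structure of the argument.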
Chain Rule With a Multivariable Function: Solution
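The multivariable chain rule states that
$$\frac{\partial g}{\partial x} = \frac{\partial g}{\partial p}\frac{\partial p}{\partial x} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial x}, \qquad \frac{\partial g}{\partial y} = \frac{\partial g}{\partial p}\frac{\partial p}{\partial y} + \frac{\partial g}{\partial q}\frac{\partial q}{\partial y}$$
We can now find the partial derivatives of $g$, evaluated at $x = 5$, $y = 4$ (where $p = 11$ and $q = 44$):
$$\frac{\partial g}{\partial p} = 2pq = 968, \quad \frac{\partial g}{\partial q} = p^2 = 121$$
Finally, the partial derivatives of $p$ and $q$ at the same point are
$$\frac{\partial p}{\partial x} = -1, \quad \frac{\partial p}{\partial y} = 4, \quad \frac{\partial q}{\partial x} = 3y = 12, \quad \frac{\partial q}{\partial y} = 3x - 4 = 11$$
Therefore
$$\frac{\partial g}{\partial x} = (968)(-1) + (121)(12) = 484, \qquad \frac{\partial g}{\partial y} = (968)(4) + (121)(11) = 5203$$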
Chain Rule With a Multivariable Function With MyGrad: Solution
# Using MyGrad to evaluate the partial derivatives of a multivariable function
>>> import mygrad as mg
# Initializes x and y as floating-point tensors
>>> x = mg.tensor(5.0)
>>> y = mg.tensor(4.0)
>>> p = 4 * y - x
>>> q = 3 * x * y - 4 * y
>>> g = p ** 2 * q
# Computes the derivatives of g with respect to all
# variables that it depends on
>>> g.backward()
>>> p.grad # stores ∂g/∂p @ x=5, y=4
array(968.)
>>> q.grad # stores ∂g/∂q @ x=5, y=4
array(121.)
>>> x.grad # stores ∂g/∂x @ x=5, y=4
array(484.)
>>> y.grad # stores ∂g/∂y @ x=5, y=4
array(5203.)