In the previous Section we detailed how we can derive derivative formulae for any function constructed from elementary functions and operations, and how the derivatives of such functions are themselves constructed from elementary functions/operations. These facts have far-reaching consequences for the practical computation of derivatives - allowing us to construct a very effective derivative calculator, called Automatic Differentiation, which we detail here.
More specifically we describe how one can quickly code up the so-called forward mode of Automatic Differentiation, a natural and direct implementation of the method for calculating derivatives 'by hand' using a computation graph, as discussed in the previous Section. Leveraging closed form derivative rules - instead of e.g., numerical approximation - the Automatic Differentiator is typically more accurate than the Numerical Differentiation calculator discussed in the second Section of this series, and has no parameters that require tuning.
While the AD calculator employs the derivative rules in precisely the same manner we have seen previously, it will not produce an algebraic description of a derivative but its programmatic analog: a program-based function / subroutine. Remember, as we saw in the previous Section, in computing derivatives we always construct a function evaluation $g(w)$ and its derivative $\frac{\mathrm{d}}{\mathrm{d}w}g(w)$ simultaneously - and this will be a fundamental aspect of our implementation here.
Since the input $w$ represents the simplest function we could possibly wish to differentiate (i.e., $g(w) = w$), and since every other mathematical function is built by combining elementary functions / operations involving $w$, a logical first step in building an AD calculator is to simply implement a numerical version of $w$. Because we need to keep track of both the evaluation and its derivative, this implementation needs to be a tuple - a set of two values - containing both the evaluation of $w$ at a given input and its derivative value there. Now obviously, since we always have that $\frac{\mathrm{d}}{\mathrm{d}w}w = 1$, the derivative value never changes, but we need to keep track of it nonetheless (since we will build the other derivative rules on top of $w$).
Below we define a simple class called MyTuple that implements the input variable $w$ (for those wanting a good introduction to Python classes in the context of mathematical functions see e.g., this excellent book). Instances of this class are a tuple containing val - the value of $w$, which defaults to val = 0 but can be adjusted to any user-defined value - and der - the derivative of $w$, which is initialized at der = 1.
class MyTuple:
    '''
    The basic object representing the input variable 'w',
    which forms the core of our AD calculator. An instance
    of this class is a tuple containing one function/derivative
    evaluation of the variable 'w'. Because it is meant to
    represent the simple variable 'w' the derivative 'der' is
    preset to 1. The value 'val' is set to 0 by default.
    '''
    def __init__(self,**kwargs):
        # variables for the value (val) and derivative (der) of our input function
        self.val = 0
        self.der = 1

        # re-assign these default values if user-specified
        if 'val' in kwargs:
            self.val = kwargs['val']
        if 'der' in kwargs:
            self.der = kwargs['der']
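A quick sanity check of the class, creating one instance at the default point and one at a user-defined point:

# create two instances of our input variable 'w'
a = MyTuple()            # default: val = 0, der = 1
b = MyTuple(val = 2.5)   # user-defined value, derivative still 1

# print their function/derivative values
print (a.val, a.der)     # prints 0 1
print (b.val, b.der)     # prints 2.5 1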
With our base constructed we can quickly code up the derivative rules for elementary functions contained in Table 1 of the previous Section, which we repeat here for convenience.
elementary function | equation | derivative |
---|---|---|
constant | $c$ | $0$ |
monomial (degree $p\neq 0$) | $w^p$ | $pw^{p-1}$ |
sine | $\text{sin}(w)$ | $\text{cos}(w)$ |
cosine | $\text{cos}(w)$ | $-\text{sin}(w)$ |
exponential | $e^w$ | $e^w$ |
logarithm | $\text{log}(w)$ | $\frac{1}{w}$ |
hyperbolic tangent | $\text{tanh}(w)$ | $1 - \text{tanh}^2(w)$ |
rectified linear unit (ReLU) | $\text{max}\left(0,w\right)$ | $\begin{cases}0 & w\leq0\\1 & w>0\end{cases}$ |
Since our variable $w$ (our MyTuple object) keeps track of both the function and derivative values, all we need to do in order to create one of these rules as a Python function is define how the elementary rule transforms the function and derivative evaluations. For example, we have a Python function for the sinusoid update rule in the next cell. Notice this is almost a direct translation of the corresponding rule into code: we first record how the function itself (sin) should affect the val attribute of an input, and then how its derivative should affect the corresponding der attribute.
# import numpy, which supplies the elementary function evaluations
import numpy as np

# our implementation of the sinusoid rule from Table 1
def sin(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.sin(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = np.cos(a.val)*a.der

    # Return updated object
    return b
Here we input $a$, a MyTuple object with a current function and derivative value, and create a new instance called $b$ to contain their updates. To get the new function value b.val = np.sin(a.val) we simply pass the current value through a sinusoid. The corresponding derivative update b.der = np.cos(a.val)*a.der involves two parts. The sinusoid derivative rule alone would have us update the derivative value by simply passing a.val through cosine. But remember - as discussed in the previous Section - that every time we apply an elementary derivative formula we must apply the chain rule as well. This is why we multiply np.cos(a.val) by a.der in the update.
We can now test our sinusoid function over a few input points, as is done in the next cell.
# create instance of our function to differentiate - notice this uses our homemade sine function, not numpy's
g = lambda w: sin(w)

# initialize our MyTuple objects at each point
a1 = MyTuple(val = 0); a2 = MyTuple(val = 0.5)

# evaluate
result1 = g(a1); result2 = g(a2)

# print results
print ('function value at ' + str(0) + ' = ' + str(result1.val))
print ('derivative value at ' + str(0) + ' = ' + str(result1.der))
print ('function value at ' + str(0.5) + ' = ' + str(result2.val))
print ('derivative value at ' + str(0.5) + ' = ' + str(result2.der))
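If all is wired up correctly the printout should read approximately as below, since $\text{sin}(0) = 0$, $\text{cos}(0) = 1$, $\text{sin}(0.5) \approx 0.479$, and $\text{cos}(0.5) \approx 0.878$.

function value at 0 = 0.0
derivative value at 0 = 1.0
function value at 0.5 = 0.4794...
derivative value at 0.5 = 0.8776...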
This looks fine - so let's evaluate over a large range of inputs and visualize both the function and derivative values. In the next Python cell we use a short custom plotting function that performs the above evaluations over a large range of input values, and then plots the resulting function/derivative values.
# define a function using our AD components
g = lambda w: sin(w)
# use custom plotter to evaluate function/derivative over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
And indeed this is correct: the function evaluation is sine and the derivative evaluation is cosine over the entire input range.
We can just as easily define a cosine function as well, and do so in the next Python cell.
# our implementation of the cosine rule from Table 1
def cos(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.cos(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = -np.sin(a.val)*a.der

    # Return updated object
    return b
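As a quick check of our own we can plot this cosine rule over a range of inputs, exactly as we did with the sinusoid:

# create instance of our function to differentiate using our homemade cosine function
g = lambda w: cos(w)

# use the custom plotting function above to plot function/derivative over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)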
We can define a power rule precisely as done with the sinusoid function in the previous example: updating the current function evaluation using the elementary function, and the current derivative value using the corresponding derivative rule from Table 1. Once again we must include an instance of the chain rule with our derivative update.
# our implementation of the power rule from Table 1
def power(a,n):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = a.val**n

    # Produce new derivative value - we need to use the chain rule here!
    b.der = n*(a.val**(n-1))*a.der

    # Return updated object
    return b
And we can test out this function by evaluating/plotting over a large range of inputs, say the second degree monomial $w^2$.
# create instance of our function to differentiate - notice this uses our homemade power function, not numpy's
g = lambda w: power(w,2)
# use the custom plotting function above to plot function/derivative over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
Again everything looks good.
As with the previous examples we can define the $\text{tanh}$ function, making sure to include the chain rule with our update of the derivative value.
# our implementation of the tanh rule from Table 1
def tanh(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.tanh(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = (1 - np.tanh(a.val)**2)*a.der

    # Return updated object
    return b
And we can test our new tanh function over a range of values.
# create instance of our function to differentiate - notice this uses our homemade tanh function, not numpy's
g = lambda w: tanh(w)
# use the custom plotting function above to plot function/derivative over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
This is correct! We can verify as much by plotting the function and derivative equations given in the table directly, which we do in the next cell using a custom plotting function.
# define function and its derivative equations using numpy
g = lambda w: np.tanh(w)
dgdw = lambda w: (1 - np.tanh(w)**2)
# plot both
calclib.plotter.derval_eq_plot(g,dgdw)
Mirroring the previous examples we can code up the derivative rule for log as follows.
# our implementation of the logarithm rule from Table 1
def log(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.log(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = (1/a.val)*a.der

    # Return updated object
    return b
And quickly testing it out, we see that it indeed works.
# create instance of our function to differentiate - notice this uses our homemade log function, not numpy's
g = lambda w: log(w)
w = np.linspace(0.01,2,1000)
# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
We can very easily continue, defining a function for each elementary derivative rule precisely as we have done with these examples. The only wrinkle to remember with each is that we must include an instance of the chain rule with each derivative update from Table 1.
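For instance, the exponential and ReLU rules from Table 1 might be sketched as below - these are illustrative versions of our own, following the same pattern as the functions above:

# a sketch of the exponential rule from Table 1
def exp(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.exp(a.val)

    # Produce new derivative value - we need to use the chain rule here!
    b.der = np.exp(a.val)*a.der

    # Return updated object
    return b

# a sketch of the ReLU rule from Table 1
def relu(a):
    # Create output evaluation and derivative object
    b = MyTuple()

    # Produce new function value
    b.val = np.maximum(0,a.val)

    # Produce new derivative value - 0 for w <= 0 and 1 for w > 0, times the chain rule factor
    b.der = np.where(a.val > 0, 1.0, 0.0)*a.der

    # Return updated object
    return b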
In the previous Section we described derivative rules for each elementary operation, which we repeat below for convenience.
operation | equation | derivative rule |
---|---|---|
addition of a constant $c$ | $g(w) + c$ | $\frac{\mathrm{d}}{\mathrm{d}w}\left(g(w) + c\right)= \frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
multiplication by a constant $c$ | $cg(w)$ | $\frac{\mathrm{d}}{\mathrm{d}w}\left(cg(w)\right)= c\frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
addition of functions (often called the summation rule) | $f(w) + g(w)$ | $\frac{\mathrm{d}}{\mathrm{d}w}(f(w) + g(w))= \frac{\mathrm{d}}{\mathrm{d}w}f(w) + \frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
multiplication of functions (often called the product rule) | $f(w)g(w)$ | $\frac{\mathrm{d}}{\mathrm{d}w}(f(w)\cdot g(w))= \left(\frac{\mathrm{d}}{\mathrm{d}w}f(w)\right)\cdot g(w) + f(w)\cdot \left(\frac{\mathrm{d}}{\mathrm{d}w}g(w)\right)$ |
composition of functions (often called the chain rule) | $f(g(w))$ | $\frac{\mathrm{d}}{\mathrm{d}w}(f(g(w)))= \frac{\mathrm{d}}{\mathrm{d}g}f(g) \cdot \frac{\mathrm{d}}{\mathrm{d}w}g(w)$ |
maximum of two functions | $\text{max}(f(w),\,g(w))$ | $\frac{\mathrm{d}}{\mathrm{d}w}(\text{max}(f(w),\,g(w))) = \begin{cases}\frac{\mathrm{d}}{\mathrm{d}w}f\left(w\right) & \text{if}\,\,\,f\left(w\right)\geq g\left(w\right)\\\frac{\mathrm{d}}{\mathrm{d}w}g\left(w\right) & \text{otherwise}\end{cases}$ |
As with the derivative formulae for elementary functions, implementing these rules means providing - in each case - an update for the function and derivative value. We walk through several examples below.
When coding up addition we may as well knock out both of the addition rules in Table 2 at once: the rule for adding a constant to a function, and the rule for adding two functions. As always we make sure to update both function and derivative values, and a simple switch or if/else statement is used below to sort between the two cases.
# our implementation of the addition rules from Table 2
def add(a,b):
    # Create output evaluation and derivative object
    c = MyTuple()

    # switch to determine if a or b is a constant
    if type(a) != MyTuple:      # a is a constant
        c.val = a + b.val
        c.der = b.der
    elif type(b) != MyTuple:    # b is a constant
        c.val = a.val + b
        c.der = a.der
    else:                       # both inputs are MyTuple objects, i.e., functions
        c.val = a.val + b.val
        c.der = a.der + b.der

    # Return updated object
    return c
With the addition rules taken care of above we can try them out. In particular we test with the sum
$$ \text{sin}(w) + w $$
since the sinusoid update rule was coded previously and the variable $w$ itself is simply a MyTuple object.
# create instance of our function to differentiate - notice this uses our homemade sine function not numpy's
g = lambda w: add(sin(w),w)
# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
Notice here that the computation graph for our input function - which includes derivative rules for both elementary functions and operations - is implicitly constructed and computed over when we pass a MyTuple object through add(sin(w),w). In other words, while we never build the computation graph itself, we are traversing it to construct the function/derivative values just as we did by hand in the previous Section.
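To make this implicit traversal concrete, the next cell (a small illustration of our own) traces a single forward sweep of $\text{sin}(w) + w$ at the point $w = 0.5$:

# trace one forward sweep through the implicit computation graph of sin(w) + w
a = MyTuple(val = 0.5)   # input node: val = 0.5, der = 1
s = sin(a)               # sinusoid node: val = sin(0.5), der = cos(0.5)*1
c = add(s,a)             # addition node: val = sin(0.5) + 0.5, der = cos(0.5) + 1

# print final function/derivative values
print (c.val, c.der)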
Since we also have several multiplication rules - what to do with the product of a constant and a function, and two functions - we can also wrap both rules up into one Python function, using a switch or if/else to distinguish between them.
# our implementation of the multiplication rules from Table 2
def multiply(a,b):
    # Create output evaluation and derivative object
    c = MyTuple()

    # switch to determine if a or b is a constant
    if type(a) != MyTuple:      # a is a constant
        c.val = a*b.val
        c.der = a*b.der
    elif type(b) != MyTuple:    # b is a constant
        c.val = a.val*b
        c.der = a.der*b
    else:                       # both inputs are MyTuple objects, i.e., functions
        c.val = a.val*b.val
        c.der = a.der*b.val + a.val*b.der # product rule

    # Return updated object
    return c
With the multiplication rules taken care of above we can try them out using two of the elementary function rules coded in the previous subsection. In particular we test with the product
$$ \text{sin}(w)\times w^2 $$
since both the sinusoid and the power function update rules were coded previously.
# create instance of our function to differentiate - notice this uses our homemade sine and power functions, not numpy's
g = lambda w: multiply(sin(w),power(w,2))
# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
Notice again that the computation graph for our input function - which includes derivative rules for both elementary functions and operations - is implicitly constructed and computed over when we pass a MyTuple object through multiply(sin(w),power(w,2)).
It's easy to check that this is correct by plotting the derivative equation itself, which using the elementary function/operation rules can be written as
$$ \frac{\mathrm{d}}{\mathrm{d}w}g(w) = 2\,\text{sin}(w)\,w + \text{cos}(w)\,w^2 $$
We plot this equation directly, along with the original function, in the next Python cell.
# define function and its derivative equations
g = lambda w: np.sin(w)*w**2
dgdw = lambda w: 2*np.sin(w)*w + np.cos(w)*w**2
# plot both
calclib.plotter.derval_eq_plot(g,dgdw)
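We can also sanity check the constant branch of our multiply function; for example, a quick test of $3\,\text{sin}(w)$, whose derivative should be $3\,\text{cos}(w)$:

# test the constant-multiplication branch of our multiply function
g = lambda w: multiply(3,sin(w))

# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)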
With our current AD calculator setup we have just seen that in order to compute the derivative of
$$ g(w) = \text{sin}(w) + w $$
we use the somewhat clunky Python notation add(sin(w),w) to invoke the summation derivative rule. In this subsection we briefly discuss how one can overload operators in Python to make the call above look more similar to the algebraic form of the function.
Python reserves symbols like + and * to denote functions that perform standard addition and multiplication on numeric objects. This means that if we try to use one of these symbols on our MyTuple objects directly we will receive an error like
unsupported operand type(s) for +: 'MyTuple' and 'MyTuple'
because our objects differ from the intended inputs to Python's default addition function, the one assigned to the + operator. We try this in the next Python cell, and indeed we provoke an error.
# create two MyTuple objects and try to use Python's built in function assigned to the + operator on them
a = MyTuple(); b = MyTuple();
a + b
We have already defined an addition function for MyTuple objects, one that updates both function and derivative values, and we can force Python to use this function whenever we use the symbol + in the context of MyTuple objects. This is called operator overloading; in short, we re-define the function Python uses when it sees a particular operator like +.
Because we want the operator + to behave differently for MyTuple objects we must overload it in the class definition of MyTuple. So we can go back and add our add function to the MyTuple class definition. In order to tell Python to use this function to overload the + operator we use the special method name __add__, as shown in the abbreviated version of the add function below.
# our implementation of the addition rules from Table 2, renamed for overloading
def __add__(self,b):
    # Create output evaluation and derivative object
    c = MyTuple()
    .
    .
    .
    # Return updated object
    return c
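For completeness, a full in-class version might look like the following sketch - our own re-packaging of the add function above, with the __init__ condensed via kwargs.get. Note that as a class method __add__ takes in only self (the left operand) and one other argument b, so only a single constant check is needed.

class MyTuple:
    '''abridged version: val/der initialization as before'''
    def __init__(self,**kwargs):
        self.val = kwargs.get('val',0)
        self.der = kwargs.get('der',1)

    # the addition rules from Table 2, written as a class method
    def __add__(self,b):
        # Create output evaluation and derivative object
        c = MyTuple()

        # switch to determine whether b is a constant
        if type(b) != MyTuple:          # adding a constant
            c.val = self.val + b
            c.der = self.der
        else:                           # adding another function
            c.val = self.val + b.val
            c.der = self.der + b.der

        # Return updated object
        return c

    # treat b + a the same as a + b
    __radd__ = __add__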
We can also do this on the fly, appending our add function to the class definition, and we do this in the next Python cell.
# this next line overloads the addition operator for our MyTuple objects, or in other words adds the 'add' function to our MyTuple class definition on the fly
MyTuple.__add__ = add
# overload the reverse direction so that a + b = b + a
MyTuple.__radd__ = add
Note we also overload the operator __radd__ above, because in Python different functions can be assigned to + depending on the ordering of the two elements being operated on. That is, we could assign different functions to + for each scenario

__add__: assigns the operation to a + b

__radd__: assigns the operation to b + a (invoked when the left operand does not know how to add a MyTuple)

if we so desired. Of course we want Python to interpret + with our objects so that a + b = b + a, and so we overload __radd__ with our add function as well.
Now we can use the +
symbol with our MyTuple objects, and Python will employ our own add
function to combine the two objects. We demonstrate this in the next Python cell.
# create two MyTuple objects and combine them using our overloaded + operator
a = MyTuple(); b = MyTuple();
a + b
b + a
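Because we overloaded __radd__ as well, constants can now sit on either side of the + symbol; a quick check of our own:

# constants now combine with MyTuple objects in either order
a = MyTuple(val = 0.5)
print ((a + 2).val, (a + 2).der)   # prints 2.5 and 1
print ((2 + a).val, (2 + a).der)   # prints 2.5 and 1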
We can do precisely the same thing with the other natural Python operators reserved for multiplication *, subtraction -, raising to a power **, etc. You can see a full list of operators that can be overloaded here. The more of these we overload appropriately the more user-friendly our AD calculator becomes.
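As one illustration of our own (not part of the calculator built above), a subtraction rule and its overloads might be sketched as follows; note that because subtraction is not commutative, the reversed operator needs its own wrapper:

# a sketch of the subtraction rules, mirroring our add function
def subtract(a,b):
    # Create output evaluation and derivative object
    c = MyTuple()

    # switch to determine if a or b is a constant
    if type(a) != MyTuple:      # constant minus function
        c.val = a - b.val
        c.der = -b.der
    elif type(b) != MyTuple:    # function minus constant
        c.val = a.val - b
        c.der = a.der
    else:                       # function minus function
        c.val = a.val - b.val
        c.der = a.der - b.der

    # Return updated object
    return c

# overload the - operator, handling both orderings
MyTuple.__sub__ = subtract
MyTuple.__rsub__ = lambda a,b: subtract(b,a)   # b - a when b is a constant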
Since we have already made functions for multiplying and raising MyTuple objects to a power, we overload these two operators on the fly in the next Python cell.
# overload the multiplication operator * with our multiply function
MyTuple.__mul__ = multiply

# overload the 'reverse multiplication' so that a*b = b*a
MyTuple.__rmul__ = multiply

# overload the power operator ** with our power function
MyTuple.__pow__ = power
With these operators overloaded we can write out several of the previous examples more naturally, which we do in the next few Python cells.
For example, above we calculated the derivative of
$$ g(w) = \text{sin}(w)w^2 $$
which we first had to write as multiply(sin(w),power(w,2)). Now we can express this derivative calculation much more naturally as sin(w)*w**2.
# create instance of our function to differentiate - notice this uses our homemade sine function and overloaded operators, not numpy's
g = lambda w: sin(w)*w**2
# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g)
In the previous Section we also computed (by hand) the derivatives of the following functions, and plotted each derivative function explicitly.
\begin{array}{c} g(w) = \text{sin}(w^3) \\ g(w) = \text{tanh}(w)\text{cos}(w) + \text{log}(w) \\ g(w) = \frac{\text{cos}(20w)}{w^2 + 1} \end{array}
In the next three Python cells we use our AD calculator to compute these derivatives, and plot the results.
# create instance of our function to differentiate - notice this uses our homemade sine function not numpy's
g = lambda w: sin(w**3)
w = np.linspace(-3,3,1000)
# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
# create instance of our function to differentiate - notice this uses our homemade functions, not numpy's
g = lambda w: tanh(w)*cos(w) + log(w)
w = np.linspace(0.01,3,1000)
# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
# create instance of our function to differentiate - notice this uses our homemade functions, not numpy's
g = lambda w: cos(20*w)*(w**2 + 1)**(-1)
w = np.linspace(-3,3,1000)
# use the custom plotting function above to plot over a large range of inputs
calclib.plotter.ad_derval_plot(MyTuple,g,w=w)
Correct again!
In constructing our AD calculator here we made a number of engineering choices which - in short - lead to a lightweight piece of code that we can develop and use rather quickly, and that is (we hope) fairly easy to understand. However there are other engineering choices one can make, some of which are merely stylistic while others present trade-offs in terms of usability and extensibility. We review a few of the more crucial engineering choices below.
Deciding on a direction for information flow - should we compute derivatives sweeping forward or backward through the computation graph?
The computation graph is a powerful tool for recursively describing how a mathematical function is constructed, how it is evaluated, and how its derivative is formed. In particular, in the previous Section we saw how recursively sweeping forward or backward through a computation graph breaks down large derivative calculations into a long sequence of much smaller more manageable ones. When performing these calculations by hand our main concerns were in computing each step accurately, and keeping the many computations organized. Thankfully because recursive algorithms are so naturally dealt with using for loops / while loops, the computation graph provides a bridge for thinking about how to perform / organize computations involving mathematical functions - like e.g., derivative calculations - on a computer instead of by hand.
But which way should we perform these computations - going forward or backward through the graph? This is really up to us to decide. For the AD calculator we built here we computed derivatives going forward, which is the simpler of the two choices in terms of bookkeeping and recursive computation - and is why it was the primary method used in the examples of the previous Section.
How do we represent the computation graph of a function?
Because AD calculations are made using the computation graph of a function $g(w)$, one engineering choice to be made is how the graph will be constructed and manipulated by our AD algorithm. Essentially we have two choices: we can either construct the computation graph implicitly - as we did in the implementation built here - or we can parse the input function $g(w)$ and construct its computation graph explicitly, as we did in pictures in the previous Section. The advantage of implicitly constructing the graph is that the corresponding calculator is lightweight and easy to build. On the other hand, implementing a calculator that explicitly constructs computation graphs requires additional technology (e.g., a parser).
While the choice to implicitly represent the computation graph makes the job of implementing an AD calculator easier, it does mean that we will need to make certain adjustments in order to extend the use of the calculator to e.g., higher order derivatives. Since the sort of optimization algorithms used in machine learning / deep learning almost universally use only first and second order derivatives, this is not too damning. However this is in contrast to explicitly constructing the computation graph, which allows for immediate higher order derivative calculations. This is because such an AD calculator takes in a function to differentiate and treats it as a computation graph, and outputs the computation graph of its derivative (which can then be plugged back into the same calculator to differentiate).
Should we compute the algebraic derivative, or the derivative evaluated at user-defined input values?
The forward mode AD calculator we built here does not provide an algebraic description of a function's derivative, but a programmatic function that can be used to evaluate the function and its derivative at any set of input points. Conversely one can build an algorithm that employs the basic derivative rules to provide an algebraic derivative, but this requires the implementation of a computer algebra system. Such a derivative calculator - one that deals with derivatives using symbolic computation (i.e., algebra on the computer) - is called a Symbolic Differentiator. However there are a few reasons why - at least for machine learning / deep learning applications - Automatic Differentiation is a better choice.
First of all, for our applications we only need a calculator that can provide a program-based description of the derivative - i.e., one that provides precise derivative values at selected input points - which the AD calculator does.
Secondly, the AD calculator requires fewer tools to build, as it relies only on basic coding methods and no computer algebra system.
Finally, expressing derivative equations algebraically can be quite unwieldy. For example, the rather complicated looking function
\begin{equation} g(w) = \text{sin}\left(e^{\,5\text{tanh}^2(w) + w^5}\right)\text{log}\left(\frac{1}{w^2 + 1} \right)\frac{w^2 + 1}{\text{cos}(\text{sin}(w))} \end{equation}
has an expansive algebraic derivative. Below are just the first few terms
\begin{equation} \frac{\mathrm{d}}{\mathrm{d}w}g(w) = -2w\,\text{sin}\left(e^{\,5\text{tanh}^2(w) + w^5}\right) \, \frac{1}{\text{cos}(\text{sin}(w))} + 2w\,\text{log}\left(\frac{1}{w^2 + 1}\right)\text{sin}\left(e^{\,5\text{tanh}^2(w) + w^5}\right)\frac{1}{\text{cos}(\text{sin}(w))} + \cdots \end{equation}
And this sort of problem grows exponentially worse - to the point of being a considerable computational burden - when dealing with multivariable functions. Such an example illustrates the real need for automatic simplification of algebraic expressions as well, ideally performed during the differentiation process to make sure things do not get too far out of hand. The AD calculator we built here - while not an algebraic method - essentially performs this simplification automatically while computing derivatives.
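For a concrete taste of the symbolic alternative, the snippet below (our own illustration, assuming the sympy library is installed) differentiates $\text{sin}(w)w^2$ algebraically:

# symbolic differentiation of sin(w)*w**2 using sympy
import sympy
w = sympy.Symbol('w')
print (sympy.diff(sympy.sin(w)*w**2, w))   # prints the algebraic derivative: w**2*cos(w) + 2*w*sin(w)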