
What is the difference between forward and reverse mode automatic differentiation?

TL;DR: Mathematically, forward and reverse mode differentiation differ only in the order in which we compute a sequence of matrix products. In practice, reverse mode differentiation is more complicated and should always-and-only be used for functions with many inputs and few outputs. Otherwise, use forward mode differentiation.

Consider a computer program $f = f_N \circ \dots \circ f_1$ composed of elemental mathematical building blocks $f_n$ and intermediate vector-valued states

$$x_{n+1} = f_{n+1}(x_n).$$

We are interested in computing the output $y = f(x)$ and its gradient $\frac{df}{dx}(x)$ for a given input $x$. The chain rule states that

$$\frac{df}{dx} = \frac{df_N}{dx_{N-1}} \frac{df_{N-1}}{dx_{N-2}} \cdots \frac{df_2}{dx_1} \frac{df_1}{dx},$$

where $\frac{df_{n+1}}{dx_n} = \frac{df_{n+1}}{dx_n}(x_n)$ is the Jacobian matrix of $f_{n+1}$ at the point $x_n$. Its size is $\operatorname{size}(x_{n+1}) \times \operatorname{size}(x_n)$. In particular, if the input size is larger than the output size, the Jacobian is a "flat" matrix (and "tall" in the other case).

By associativity of matrix multiplication, i.e. $(AB)C = A(BC)$ for all matrices $A$, $B$, and $C$, we can choose to write the chain rule in two ways:

$$\frac{df}{dx} = \frac{df_N}{dx_{N-1}} \left( \cdots \frac{df_3}{dx_2} \left( \frac{df_2}{dx_1} \left( \frac{df_1}{dx} \right) \right) \right) = \left( \left( \left( \frac{df_N}{dx_{N-1}} \right) \frac{df_{N-1}}{dx_{N-2}} \right) \frac{df_{N-2}}{dx_{N-3}} \right) \cdots \frac{df_1}{dx}.$$

The first form, where we evaluate the expression from right to left, is called forward mode differentiation. The second form, where we evaluate the expression from left to right, is called reverse mode differentiation. In both cases, when we evaluate the full gradient, we do not need to build all these Jacobian matrices explicitly (which could be costly for large sizes). Instead, it is sufficient to be able to compute the Jacobian-vector products for forward mode and vector-Jacobian products for reverse mode.
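As a toy illustration of why the ordering matters (with made-up sizes, just to count operations), consider a chain of three Jacobians $A$, $B$, and $C$ of sizes $1 \times 3$, $3 \times 3$, and $3 \times 3$, corresponding to a scalar output and three-dimensional intermediate states:

$$\underbrace{A}_{1 \times 3} \, \underbrace{B}_{3 \times 3} \, \underbrace{C}_{3 \times 3}.$$

Computing $A(BC)$ costs $27 + 9 = 36$ scalar multiplications, since $BC$ is a full $3 \times 3$ matrix product, while $(AB)C$ costs only $9 + 9 = 18$, since every intermediate result stays a row vector. This cost difference is exactly what separates the two modes below.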

Forward mode differentiation

The pushforward function (or Jacobian-vector product) of a differentiable function $f$ is defined as

$$\dot{f}(x, \dot{x}) = \frac{df}{dx}(x) \, \dot{x},$$

where $x$ is a given state vector and $\dot{x}$ is a gradient seed being "pushed forward" in the computational chain. The seed $\dot{x}$ can be a column vector or a matrix (a collection of column vectors) of the same size as $x$.

Since both the gradients and the states can be computed at the same time, a sequential program $f = f_N \circ \dots \circ f_1$ can be executed with a single for-loop without saving any intermediate states:

  1. Initialize the state $x_0 = x$ and the seed $\dot{x}_0 = I$, given as an identity matrix of the same size as $x$.
  2. For $n \in \{0, \dots, N-1\}$, compute $x_{n+1} = f_{n+1}(x_n)$ and $\dot{x}_{n+1} = \dot{f}_{n+1}(x_n, \dot{x}_n)$. Now $x_n$ and $\dot{x}_n$ are no longer needed and can be discarded.
  3. Return the full gradient $\dot{x}_N = \frac{df}{dx}(x)$.

Here we assume that the computer has access to the pushforward functions of all the elemental functions $(f_n)_{n=1}^N$ in the program (e.g. the rule $\dot{\cos}(x, \dot{x}) = -\sin(x) \, \dot{x}$ must be hard-coded somewhere if you want to use $\cos(x)$ in your program).
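As a minimal sketch of this loop (not a full AD engine; the container names `fs` and `pushforwards` and the toy chain are made up for illustration), this could look as follows:

```julia
# `fs` holds the elemental functions f_1, ..., f_N and `pushforwards` the
# corresponding hard-coded rules (x, xdot) -> ḟ_n(x, xdot).
function forward_gradient(fs, pushforwards, x)
    xdot = one(x) # Seed: 1 for a scalar input (an identity matrix for vectors)
    for (f, fdot) in zip(fs, pushforwards)
        x, xdot = f(x), fdot(x, xdot) # Advance the state and the seed together
    end
    return xdot # xdot now equals the full gradient df/dx at the given input
end

# Toy chain f(x) = sin(2x) with hand-written pushforward rules
fs = [x -> 2x, sin]
pushforwards = [(x, xdot) -> 2xdot, (x, xdot) -> cos(x) * xdot]
forward_gradient(fs, pushforwards, 0.0) # Returns 2.0, i.e. 2cos(0)
```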

At first glance, forward mode differentiation seems simple and efficient. It does not require storing all previous states. However, for use cases where the input of the program is high-dimensional and the output is low-dimensional, it can be computationally expensive. This is often the case in deep learning, where the input $x$ contains millions of neural network weights and the output $y = f(x)$ is a scalar loss value. In that case, at each step in the program but the last, $\dot{x}_n$ is a large matrix of size $\operatorname{size}(x_n) \times \operatorname{size}(x)$, potentially millions by millions, even though the final gradient $\frac{df}{dx}$ is just a vector of the same size as $x$.

Forward mode AD in 15 lines of Julia code

A pleasant way of implementing forward mode AD is to define a dual number type $d = (x, \dot{x})$ with the fundamental property $f(d) = (f(x), \dot{f}(x, \dot{x}))$ for all differentiable functions $f \in C^1$. In Julia, by adding methods to basic arithmetic functions such as addition and multiplication for dual numbers, many programs can be made differentiable without any modification at all. Example (for scalar functions composed of additions and multiplications):

```julia
import Base: +, * # The functions + and * are not "our" functions
struct Dual
    x
    xdot
end
Dual(x::Real) = Dual(x, 0) # Convert a real number to a dual with xdot = 0
gradient(f, x) = f(Dual(x, 1)).xdot # The gradient is obtained with seed xdot = 1
+(a::Dual, b::Dual) = Dual(a.x + b.x, a.xdot + b.xdot)
+(a::Dual, b::Real) = a + Dual(b)
+(a::Real, b::Dual) = Dual(a) + b
*(a::Dual, b::Dual) = Dual(a.x * b.x, a.xdot * b.x + a.x * b.xdot) # Product rule
*(a::Dual, b::Real) = a * Dual(b)
*(a::Real, b::Dual) = Dual(a) * b
gradient(x -> x * x * x, 4) # Returns 48
gradient(x -> 3x + 5, 2) # Returns 3
```

Furthermore, when inspecting the generated LLVM code, we see that the compiler is able to infer the return value 3 at compile time (even before the code is run with the input value 2!):

```julia
@code_llvm gradient(x -> 3x + 5, 2)
```

```llvm
; Function Signature: gradient(Main.var"#31#32", Int64)
;  @ REPL[15]:1 within `gradient`
define i64 @julia_gradient_2898(i64 signext %"x::Int64") #0 {
top:
; ┌ @ REPL[29]:1 within `#31`
; │┌ @ REPL[10]:1 within `+` @ REPL[17]:1 @ int.jl:87
    ret i64 3
; └└
}
```

Gradients of large programs can thus be expected to generate similarly efficient code.

Reverse mode differentiation

The pullback function (or vector-Jacobian product) of a differentiable function $f$ is defined as

$$\bar{f}(x, \bar{y}) = \bar{y} \, \frac{df}{dx}(x),$$

where $x$ is a given state vector and $\bar{y}$ is an adjoint variable being "pulled back" in the computational chain. The adjoint variable $\bar{y}$ should be a row vector or a matrix (a collection of row vectors) of the same size as the column vector $y = f(x)$.

For convenience, we also define the (partially applied) pullback function at a given state $x$ as

$$\bar{f}(x) : \bar{y} \mapsto \bar{f}(x, \bar{y}).$$
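For example, mirroring the hard-coded pushforward rule for $\cos$ mentioned above, the corresponding pullback rule would read

$$\bar{\cos}(x, \bar{y}) = -\bar{y} \sin(x), \qquad \bar{\cos}(x) : \bar{y} \mapsto -\bar{y} \sin(x).$$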

The canonical reverse mode differentiation algorithm is implemented using two for-loops: a forward pass to compute (and store!) the states, and a reverse pass to compute the gradients (a minimal Julia sketch follows the list):

  1. Assign the initial state $x_0 = x$.
  2. Forward pass: for $n \in \{0, \dots, N-1\}$, compute $x_{n+1} = f_{n+1}(x_n)$ and $g_{n+1} = \bar{f}_{n+1}(x_n)$, and store the pullback $g_{n+1}$ for later use (this may require storing the full state $x_n$ so that $g_{n+1}$ can be called).
  3. Initialize the final adjoint seed $\bar{x}_N = I$, given as an identity matrix of the same size as the output vector $x_N$ (typically $x_N$ is a scalar and $\bar{x}_N = 1$).
  4. Backward pass: for decreasing $n \in \{N, \dots, 1\}$, compute $\bar{x}_{n-1} = g_n(\bar{x}_n)$. Now $\bar{x}_n$ and $g_n$ are no longer needed and can be discarded.
  5. Return the full gradient $\bar{x}_0 = \frac{df}{dx}(x)$.
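A minimal sketch of this two-pass algorithm (again with made-up container names `fs` and `pullbacks`, and closures standing in for the stored $g_{n+1}$) could look like this:

```julia
# `fs` holds the elemental functions f_1, ..., f_N and `pullbacks` the
# corresponding hard-coded rules (x, ybar) -> f̄_n(x, ybar).
function reverse_gradient(fs, pullbacks, x)
    gs = [] # Storage for the partially applied pullbacks g_1, ..., g_N
    for (f, fbar) in zip(fs, pullbacks) # Forward pass
        xn = x # Local copy so that the closure below captures the current x_n
        push!(gs, ybar -> fbar(xn, ybar)) # g_{n+1} = f̄_{n+1}(x_n)
        x = f(x)
    end
    xbar = one(x) # Adjoint seed: 1 for a scalar output
    for g in reverse(gs) # Backward pass
        xbar = g(xbar)
    end
    return xbar # xbar now equals the full gradient df/dx at the given input
end

# Same toy chain f(x) = sin(2x), now with hand-written pullback rules
fs = [x -> 2x, sin]
pullbacks = [(x, ybar) -> 2ybar, (x, ybar) -> ybar * cos(x)]
reverse_gradient(fs, pullbacks, 0.0) # Returns 2.0
```

Note that the stored closures keep the states $x_n$ alive, which is exactly the memory cost discussed below.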

For high-dimensional inputs and low-dimensional outputs, reverse mode differentiation is the method of choice. However, the two for-loops in opposite order ("forward" and "back-propagation") and the requirement to store all the intermediate states of the forward pass can cause quite some headaches (and computer memory issues). Multiple strategies exist to mitigate these issues:

  • Checkpointing: If storing all the states $(x_n)_{n=0}^{N-1}$ takes up too much space, we can store every 10th state and recompute the 9 missing states between the current and the next stored state when needed (see the sketch after this list).
  • Reverse accumulation (for the bravehearted): Do a forward pass to compute $x_N$, but do not store any intermediate states. In the backward pass, compute $x_{n-1} = f_n^{-1}(x_n)$ alongside $\bar{x}_{n-1} = g_n(\bar{x}_n)$. However, the program components $(f_n)_{n=1}^N$ may be badly conditioned or even non-invertible, in which case this method should not be used.
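A minimal sketch of the checkpointing strategy, reusing the hypothetical `fs` and `pullbacks` containers from the sketch above (store every `stride`-th state, recompute the missing ones on demand):

```julia
function reverse_gradient_checkpointed(fs, pullbacks, x; stride = 10)
    N = length(fs)
    checkpoints = Dict(0 => x) # x_0 is always stored
    for n in 1:N # Forward pass: keep only every `stride`-th state
        x = fs[n](x)
        n % stride == 0 && (checkpoints[n] = x)
    end
    xbar = one(x) # Adjoint seed for a scalar output
    for n in N:-1:1 # Backward pass
        k = stride * fld(n - 1, stride) # Index of the last stored state before x_{n-1}
        xk = checkpoints[k]
        for m in (k + 1):(n - 1) # Recompute the discarded states x_{k+1}, ..., x_{n-1}
            xk = fs[m](xk)
        end
        xbar = pullbacks[n](xk, xbar) # Pull the adjoint back one step
    end
    return xbar
end

reverse_gradient_checkpointed(fs, pullbacks, 0.0; stride = 2) # Returns 2.0
```

This naive sketch recomputes from the last checkpoint at every backward step; a more careful implementation would recompute each segment of `stride` states only once and keep it until the whole segment has been pulled back through.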

How "automatic" is automatic differentiation?

AD engines work by decomposing programs into mathematical building blocks $(f_n)_{n=1}^N$ that they know how to differentiate. This knowledge needs to be hard-coded by humans. The ChainRules.jl ecosystem in Julia provides a nice framework for specifying pushforward and pullback rules, along with pre-defined rules for many common functions. In computers, most mathematical functions are implemented as some form of truncated power series expansion, for example the exponential function

$$\exp(x) = 1 + x + \frac{x^2}{2} + \frac{x^3}{6} + \dots,$$

which is probably implemented as some variant of

$$s_N(x) = \sum_{n=0}^N \frac{x^n}{n!},$$

where $N$ depends on the desired accuracy. A naive AD engine that only knows how to differentiate polynomials might decide to compute the derivative

$$\frac{ds_N}{dx}(x) = \sum_{n=1}^N \frac{n x^{n-1}}{n!} = s_{N-1}(x),$$

which, if the coefficients are properly merged, gives the same expression as $s_N(x)$ but with precision $N-1$ instead of $N$. In addition, during a forward mode AD pass, the AD engine would compute both $y = s_N(x)$ and $\frac{ds_N}{dx}(x)$, not knowing that it could in fact reuse the value $y$ in $\dot{y} = y \dot{x}$. With a hard-coded pushforward rule for $s_N(x)$, we could tell the AD system to compute $y = s_N(x)$ once, and then return $(y, y \dot{x})$. Since a pushforward rule for $s_N$ then exists in the rule table, the AD system would decide not to decompose $s_N$ further, and instead use the rule directly.
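As a toy illustration, such a hard-coded rule could be added to the `Dual` type from the forward mode example above by treating `exp` as an elemental function and reusing $y$ for both outputs:

```julia
import Base: exp
function exp(d::Dual)
    y = exp(d.x) # Compute exp(x) once
    Dual(y, y * d.xdot) # Reuse y in the pushforward: ẏ = y ẋ
end
gradient(exp, 1.0) # Returns 2.718281828459045, i.e. exp(1)
```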

WARNING

Numerically speaking, this rule is actually wrong for $s_N$, but correct for $\exp$, which $s_N$ is supposed to approximate. If we need exact gradients of the numerical implementation, the pushforward should return $(s_N(x), s_{N-1}(x) \dot{x})$ instead of $(y, y \dot{x})$ with $y = s_N(x)$.

Similar arguments can be made for larger algorithms, such as solving a linear system $A y = x$ using an iterative solver. The function $f : x \mapsto A^{-1} x$ is then implemented using a for-loop where we compute matrix-vector products such as $A x_n$ for intermediate guesses $x_n$. The gradient of $f$ is given by $\frac{df}{dx}(x) = A^{-1}$, and so the pushforward rule is $\dot{f}(x, \dot{x}) = A^{-1} \dot{x} = f(\dot{x})$. Instead of differentiating the entire for-loop, we could just do a new linear solve to compute $\dot{f}(x, \dot{x})$, possibly converging in a different number of iterations.

For forward mode, it is probably fine not to implement the rule for iterative solvers. For reverse mode differentiation of a linear solve using the conjugate gradient method for a symmetric positive definite $A$, creating two for-loops with a forward pass and a backward pass would be disastrous, when we know that the pullback is

$$\bar{f}(x, \bar{y}) = \bar{y} \, \frac{df}{dx}(x) = \bar{y} A^{-1} = (A^{-1} \bar{y}^T)^T = f(\bar{y}^T)^T,$$

since $(A^{-1})^T = A^{-1}$. We could just run the linear solver twice: once for $y = f(x)$, and once for $\bar{x} = f(\bar{y}^T)^T$.
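As a sketch of these hand-written rules in Julia (with the backslash solve standing in for an iterative method such as conjugate gradients; the function names are made up for illustration):

```julia
using LinearAlgebra

f(A, x) = A \ x # y = A⁻¹ x, e.g. via the conjugate gradient method
pushforward(A, x, xdot) = A \ xdot # ḟ(x, ẋ) = A⁻¹ ẋ = f(ẋ)
pullback(A, x, ybar) = (A \ ybar')' # f̄(x, ȳ) = ȳ A⁻¹ = (A⁻¹ ȳᵀ)ᵀ, using Aᵀ = A

A = [2.0 1.0; 1.0 3.0] # Symmetric positive definite
x = [1.0, 2.0]
ybar = [1.0 0.0] # Row vector adjoint
pullback(A, x, ybar) # Returns the first row of A⁻¹, via one extra linear solve
```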

Finally, the gradient of a program that solves a (partial) differential equation might (or might not) be better computed by deriving a mathematical equation for the gradient of the exact continuous solution and then discretizing it, instead of discretizing the equation and then differentiating.

Conclusion

While many programs can be differentiated using a naive AD engine that knows only generic rules for elementary functions, large programs that are costly to evaluate should be analyzed for potential performance improvements. This is especially important for reverse mode differentiation, where it can also be rewarding to choose a suitable checkpointing strategy. Users should also consider whether they are interested in the exact gradient of the numerical implementation, or whether the gradient of the mathematical function that the program approximates can be computed more efficiently.

See also