What is the difference between forward and reverse mode automatic differentiation?
TL;DR: Mathematically, forward and reverse mode differentiation differ only in the order in which we compute a sequence of matrix products. In practice, reverse mode differentiation is more complicated and should be reserved for functions with many inputs and few outputs. Otherwise, use forward mode differentiation.
Consider a computer program that maps an input $x_0$ to an output $x_T$ through a sequence of intermediate states
$$x_t = f_t(x_{t-1}), \quad t = 1, \dots, T.$$
We are interested in computing the output Jacobian (the "full gradient")
$$J = \frac{\partial x_T}{\partial x_0} = J_T J_{T-1} \cdots J_1,$$
where $J_t = f_t'(x_{t-1})$ denotes the Jacobian matrix of $f_t$ evaluated at $x_{t-1}$.
By associativity of matrix multiplication, i.e.
$$J_T \left( J_{T-1} \left( \cdots \left( J_2 J_1 \right) \right) \right) = \left( \left( \left( J_T J_{T-1} \right) J_{T-2} \right) \cdots \right) J_1,$$
we are free to choose the order in which the product is evaluated.
The first form, where we evaluate the expression from right to left, is called forward mode differentiation. The second form, where we evaluate the expression from left to right, is called reverse mode differentiation. In both cases, when we evaluate the full gradient, we do not need to build all these Jacobian matrices explicitly (which could be costly for large sizes). Instead, it is sufficient to be able to compute the Jacobian-vector products for forward mode and vector-Jacobian products for reverse mode.
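To make the cost difference concrete, here is a small sketch in Julia, with plain random matrices standing in for the Jacobians $J_t$ (the sizes are made up for illustration): both orders give the same Jacobian, but the sizes of the intermediate products differ.

```julia
using LinearAlgebra

J1 = randn(10, 1000)   # Jacobian of the first step (many inputs)
J2 = randn(10, 10)     # Jacobian of the middle step
J3 = randn(1, 10)      # Jacobian of the last step (scalar output)

# Forward mode: evaluate right to left, starting from the input side.
J_forward = J3 * (J2 * J1)   # intermediate J2 * J1 has size 10 x 1000

# Reverse mode: evaluate left to right, starting from the output side.
J_reverse = (J3 * J2) * J1   # intermediate J3 * J2 has size 1 x 10

J_forward ≈ J_reverse        # true: both orders give the same Jacobian
```

With a 1000-dimensional input and a scalar output, the reverse grouping only ever multiplies small row vectors, while the forward grouping carries a 10 × 1000 block through every step.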
Forward mode differentiation
The pushforward function (or Jacobian-vector product) of a differentiable function $f$ at a point $x$ is defined as
$$\operatorname{pushforward}(f, x, \dot{x}) = f'(x) \, \dot{x},$$
where $f'(x)$ is the Jacobian of $f$ at $x$ and $\dot{x}$ is a tangent vector ("seed") of the same size as $x$.
Since both the gradients and the states can be computed at the same time, a sequential program can be differentiated in a single forward sweep:
- Initialize the state $x_0$ and the seed $\dot{x}_0$, given as an identity matrix whose size matches the dimension of $x_0$.
- For $t = 1, \dots, T$, compute
  $$x_t = f_t(x_{t-1}), \qquad \dot{x}_t = J_t \dot{x}_{t-1} = \operatorname{pushforward}(f_t, x_{t-1}, \dot{x}_{t-1}).$$
  Now $x_{t-1}$ and $\dot{x}_{t-1}$ are no longer needed and can be discarded.
- Return the full gradient $\dot{x}_T = J_T J_{T-1} \cdots J_1$.
Here we assume that the computer has access to the pushforward functions of all the elemental functions $f_t$ that make up the program.
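As a minimal sketch of the loop above, assuming each program step comes with a pushforward function `(x, ẋ) -> (f_t(x), J_t * ẋ)` (the names `forward_mode` and `pushforwards` are illustrative, not from any particular AD package):

```julia
# Hedged sketch of the forward-mode loop. Each element of `pushforwards` is
# assumed to be a function (x, ẋ) -> (f_t(x), J_t * ẋ), one per program step.
function forward_mode(pushforwards, x0, ẋ0)
    x, ẋ = x0, ẋ0
    for pf in pushforwards
        x, ẋ = pf(x, ẋ)   # the previous x and ẋ can now be discarded
    end
    return x, ẋ           # output x_T and full gradient ẋ_T
end

# Example: the program x -> 3 * sin(x), split into two steps, at x0 = 2.0.
pushforwards = [
    (x, ẋ) -> (sin(x), cos(x) * ẋ),
    (x, ẋ) -> (3 * x, 3 * ẋ),
]
forward_mode(pushforwards, 2.0, 1.0)   # (3sin(2), 3cos(2))
```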
At first glance, forward mode differentiation seems simple and efficient. It does not require storing all previous states. However, for use cases where the input of the program is high-dimensional and the output is low-dimensional, it can be computationally expensive: the seed $\dot{x}_0$ is then an identity matrix with as many columns as there are inputs, and every one of those columns has to be pushed through the program. This is often the case in deep learning, where the input $x_0$ contains all the network weights (often millions of them) and the output $x_T$ is a single scalar loss value.
Forward mode AD in 15 lines of Julia code
A pleasant way of implementing forward mode AD is to define a dual number type that carries both a value and its derivative (tangent):
import Base: +, * # The functions + and * are not "our" functions
struct Dual
    x     # primal value
    xdot  # tangent (derivative) value
end
Dual(x::Real) = Dual(x, 0) # Convert real number to dual with xdot=0
gradient(f, x) = f(Dual(x, 1)).xdot # Gradient is obtained with seed xdot=1
+(a::Dual, b::Dual) = Dual(a.x + b.x, a.xdot + b.xdot)
+(a::Dual, b::Real) = a + Dual(b)
+(a::Real, b::Dual) = Dual(a) + b
*(a::Dual, b::Dual) = Dual(a.x * b.x, a.xdot * b.x + a.x * b.xdot) # Product rule
*(a::Dual, b::Real) = a * Dual(b)
*(a::Real, b::Dual) = Dual(a) * b
gradient(x -> x * x * x, 4) # Returns 48
gradient(x -> 3x + 5, 2) # Returns 3
Furthermore, when inspecting the generated LLVM code, we see that the compiler is able to infer the return value 3 at compile time (even before the code is run with the input value 2!):
@code_llvm gradient(x -> 3x + 5, 2)
; Function Signature: gradient(Main.var"#31#32", Int64)
; @ REPL[15]:1 within `gradient`
define i64 @julia_gradient_2898(i64 signext %"x::Int64") #0 {
top:
; ┌ @ REPL[29]:1 within `#31`
; │┌ @ REPL[10]:1 within `+` @ REPL[17]:1 @ int.jl:87
ret i64 3
; └└
}
Gradients of large programs can thus be expected to generate similarly efficient code.
Reverse mode differentiation
The pullback function (or vector-Jacobian product) of a differentiable function $f$ at a point $x$ is defined as
$$\operatorname{pullback}(f, x, \bar{x}) = f'(x)^{\mathsf{T}} \bar{x},$$
where $\bar{x}$ is an adjoint (cotangent) vector of the same size as the output $f(x)$.
For convenience, we also define the (partially applied) pullback function at a given state $x_{t-1}$:
$$\operatorname{pullback}_t(\bar{x}) = \operatorname{pullback}(f_t, x_{t-1}, \bar{x}) = J_t^{\mathsf{T}} \bar{x}.$$
The canonical reverse mode differentiation algorithm is implemented using two for-loops, a forward pass to compute (and store!) the states and a backward pass to compute the gradients (a sketch in code follows the list):
- Assign the initial state $x_0$.
- Forward pass: for $t = 1, \dots, T$, compute
  $$x_t = f_t(x_{t-1})$$
  and store the pullback $\operatorname{pullback}_t$ for later use (this may require storing the full state $x_{t-1}$ so that $\operatorname{pullback}_t$ can be called).
- Initialize the final adjoint seed $\bar{x}_T$ as an identity matrix of the same size as the output vector $x_T$ (typically $x_T$ and $\bar{x}_T$ are both scalars).
- Backward pass: for decreasing $t = T, \dots, 1$, compute
  $$\bar{x}_{t-1} = \operatorname{pullback}_t(\bar{x}_t) = J_t^{\mathsf{T}} \bar{x}_t.$$
  Now $\bar{x}_t$ and $\operatorname{pullback}_t$ are no longer needed and can be discarded.
- Return the full gradient $\bar{x}_0 = \left( J_T J_{T-1} \cdots J_1 \right)^{\mathsf{T}}$.
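A minimal sketch of this two-pass scheme, under the assumption that each step is given as a function `x -> (f_t(x), pullback_t)`; again, the names are illustrative rather than taken from an existing package:

```julia
# Hedged sketch of the two-pass reverse-mode loop. Each element of `steps` is
# assumed to be a function x -> (f_t(x), pullback_t), with pullback_t(x̄) = Jₜ' * x̄.
function reverse_mode(steps, x0, x̄T)
    # Forward pass: compute the states and store the pullbacks.
    x = x0
    pullbacks = []
    for step in steps
        x, pb = step(x)
        push!(pullbacks, pb)   # the stored pullback may capture the state (memory!)
    end
    # Backward pass: propagate the adjoint in reverse order.
    x̄ = x̄T
    for pb in reverse(pullbacks)
        x̄ = pb(x̄)
    end
    return x, x̄               # output x_T and full gradient x̄_0
end

# Same example program as before: x -> 3 * sin(x) at x0 = 2.0, with seed x̄T = 1.0.
steps = [
    x -> (sin(x), x̄ -> cos(x) * x̄),
    x -> (3 * x, x̄ -> 3 * x̄),
]
reverse_mode(steps, 2.0, 1.0)   # (3sin(2), 3cos(2))
```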
For high-dimensional inputs and low-dimensional outputs, reverse mode differentiation is the method of choice. However, the two for-loops running in opposite order ("forward" and "back-propagation") and the requirement to store all the intermediate states of the forward pass can cause quite some headaches (and computer memory issues). Multiple strategies exist to mitigate these issues:
- Checkpointing: if storing all the states $x_0, \dots, x_{T-1}$ takes up too much space, we can store only every 10th state and recompute the 9 missing states between the current and the next stored state when they are needed in the backward pass (see the sketch after this list).
- Reverse accumulation (for the bravehearted): do a forward pass to compute $x_T$, but do not store any intermediate states. In the backward pass, recompute $x_{t-1} = f_t^{-1}(x_t)$ alongside $\bar{x}_{t-1}$. However, the program components $f_t$ may be badly conditioned or even non-invertible, in which case this method should not be used.
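A minimal sketch of the checkpointing strategy, reusing the `x -> (f_t(x), pullback_t)` step format from the sketch above and storing only every `k`-th state (the function name and interface are illustrative):

```julia
# Hedged sketch of checkpointed reverse mode. Only every k-th state is stored;
# missing states are recomputed from the nearest checkpoint during the backward pass.
function reverse_mode_checkpointed(steps, x0, x̄T; k = 10)
    T = length(steps)
    checkpoints = Dict(0 => x0)        # stored states x_t, keyed by t
    x = x0
    for t in 1:T                       # forward pass: discard the pullbacks
        x, _ = steps[t](x)
        t % k == 0 && (checkpoints[t] = x)
    end
    y = x                              # program output x_T
    x̄ = x̄T
    for t in T:-1:1                    # backward pass
        s = k * div(t - 1, k)          # nearest stored checkpoint at or before t - 1
        xs = checkpoints[s]
        for r in (s + 1):(t - 1)       # recompute the missing states
            xs, _ = steps[r](xs)
        end
        _, pb = steps[t](xs)           # rebuild pullback_t at the state x_{t-1}
        x̄ = pb(x̄)
    end
    return y, x̄
end

reverse_mode_checkpointed(steps, 2.0, 1.0; k = 1)   # same result as reverse_mode above
```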
How "automatic" is automatic differentiation?
AD engines work by decomposing programs into mathematical building blocks for which derivative rules are known, such as addition, multiplication, and the exponential. Consider for example the power function $x \mapsto x^n$, which is probably implemented as some variant of repeated multiplication,
$$x^n = x \cdot x \cdots x.$$
Differentiating this implementation means applying the product rule over and over, producing a sum of $n$ terms, where each term equals $x^{n-1} \dot{x}$, which, if the coefficients are properly merged, gives the same expression as the textbook rule $(x^n)' = n x^{n-1}$.
WARNING
Numerically speaking, this rule is actually wrong for certain edge cases: at $x = 0$ with $n = 0$, for instance, $n x^{n-1}$ evaluates to $0 \cdot 0^{-1} = \mathrm{NaN}$ in floating point, while the true derivative is $0$. Even simple textbook rules therefore require some care.
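To make this concrete, here is a sketch of how such a dedicated (merged) rule could be attached to the Dual type defined above. It is an illustration only, subject to the warning above, and not a complete or numerically robust rule:

```julia
import Base: ^   # extend ^ for our Dual type

# A dedicated (merged) power rule, so that x^n uses n * x^(n-1) directly instead
# of applying the product rule n - 1 times. Not numerically safe for all x and n.
^(a::Dual, n::Integer) = Dual(a.x^n, n * a.x^(n - 1) * a.xdot)

gradient(x -> x^3, 4)   # Returns 48, same as gradient(x -> x * x * x, 4)
```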
Similar arguments can be made for larger algorithms, such as solving a linear system $A x = b$: instead of letting the AD engine differentiate through every internal operation of the solver, we can register a derivative rule for the solve as a whole.
For forward mode, it is probably fine not to implement the rule for iterative solvers. For reverse mode differentiation of a linear solve using the conjugate gradient method for a symmetric positive definite matrix $A$, the pullback $\bar{b} = A^{-\mathsf{T}} \bar{x}$ can be computed by calling the very same conjugate gradient solver on the adjoint $\bar{x}$, since $A^{\mathsf{T}} = A$.
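As a hedged sketch of this idea (the names `solve`, `solve_pullback`, and the hand-rolled `cg` are illustrative, not from a specific package): because $A$ is symmetric, the adjoint system is solved with the very same conjugate gradient routine as the forward solve.

```julia
using LinearAlgebra

# Minimal conjugate gradient solver for a symmetric positive definite matrix A.
function cg(A, b; tol = 1e-10, maxiter = length(b))
    x = zero(b)
    r = b - A * x
    p = copy(r)
    rs = dot(r, r)
    for _ in 1:maxiter
        Ap = A * p
        α = rs / dot(p, Ap)
        x = x + α * p
        r = r - α * Ap
        rs_new = dot(r, r)
        sqrt(rs_new) < tol && break
        p = r + (rs_new / rs) * p
        rs = rs_new
    end
    return x
end

solve(A, b) = cg(A, b)   # forward computation: x = A⁻¹ b

# Pullback for the whole solve: given the adjoint x̄ of the solution x, return (Ā, b̄).
# In general b̄ = A⁻ᵀ x̄ and Ā = -b̄ x'; since A is symmetric, A⁻ᵀ = A⁻¹, so the
# adjoint system can be solved with the very same cg routine.
function solve_pullback(A, b, x, x̄)
    b̄ = cg(A, x̄)
    Ā = -b̄ * x'
    return Ā, b̄
end

A = [4.0 1.0; 1.0 3.0]   # symmetric positive definite
b = [1.0, 2.0]
x = solve(A, b)
Ā, b̄ = solve_pullback(A, b, x, [1.0, 0.0])   # adjoints for the first solution component
```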
Finally, the gradient of a program that solves a (partial) differential equation might (or might not) be better computed by first deriving a mathematical equation for the gradient of the exact continuous solution and then discretizing it, instead of discretizing the equation and then differentiating the discretization.
Conclusion
While many programs can be differentiated using a naive AD engine that only knows generic rules for elementary functions, large programs that are costly to evaluate should be analyzed for potential performance improvements. This is especially important for reverse mode differentiation, where it can also be rewarding to choose a suitable checkpointing strategy. Users should also consider whether they are interested in the exact gradient of the numerical implementation, or whether the gradient of the mathematical function the program approximates can be computed more efficiently.
See also
- ChainRules.jl documentation: Many nice explanations
- The SciML book: A book on scientific machine learning, including AD of differential equations
- Automatic differentiation from scratch: Nice example of forward mode AD in Julia