
Differentiating through the code

IncompressibleNavierStokes is reverse-mode differentiable, which means that you can back-propagate gradients through the code. Two AD libraries are currently supported:

  • Zygote.jl: the default AD library in the Julia ecosystem and the most widely used.

  • Enzyme.jl: currently supports only a subset of the Julia language, but it is usually the most efficient option when it applies.

Automatic differentiation with Zygote

Zygote.jl is the default choice of AD backend because it is easy to understand, compatible with most of the Julia ecosystem, and works well with vectorized code and BLAS. This comes at a cost, however: intermediate velocity fields need to be stored in memory for use in the backward pass, and array mutation is not supported. For this reason, many of the operators come in two versions: a slow, differentiable, allocating, non-mutating variant (e.g. divergence) and a fast, non-differentiable, non-allocating, mutating variant (e.g. divergence!).

Zygote limitation: array mutation

To make your code differentiable, you must use the differentiable versions of the operators (without the exclamation marks).
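The difference between the two kinds of operators can be illustrated with a minimal sketch. The functions double and double! below are hypothetical toy operators, not part of IncompressibleNavierStokes; they only mimic the allocating and in-place naming convention.

```julia
using Zygote

# Hypothetical toy operators (assumption: not from IncompressibleNavierStokes).
# Allocating, non-mutating variant: returns a new array.
double(u) = 2 .* u

# In-place, mutating variant: overwrites `out` and returns it.
function double!(out, u)
    for i in eachindex(out, u)
        out[i] = 2 * u[i]
    end
    return out
end

u = [1.0, 2.0, 3.0]

# The non-mutating variant is differentiable.
# Here d/du sum((2u).^2) = 8u.
g, = Zygote.gradient(v -> sum(abs2, double(v)), u)

# The mutating variant makes Zygote throw, since it cannot
# differentiate through writes into an existing array.
mutation_failed = try
    Zygote.gradient(v -> sum(abs2, double!(similar(v), v)), u)
    false
catch
    true
end
```

In this sketch g equals 8 .* u, while mutation_failed ends up true because the in-place write is rejected by Zygote.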

Example: Gradient of kinetic energy

To differentiate outputs of a simulation with respect to the initial conditions, make a time stepping loop composed of differentiable operations:

julia
using IncompressibleNavierStokes

ax = range(0, 1, 101)
setup = Setup(; x = (ax, ax), Re = 500.0)
psolver = default_psolver(setup)
method = RKMethods.RK44P2()
Δt = 0.01
nstep = 100
(; Iu) = setup.grid
function final_energy(u)
    stepper = create_stepper(method; setup, psolver, u, temp = nothing, t = 0.0)
    for it = 1:nstep
        stepper = timestep(method, stepper, Δt)
    end
    (; u) = stepper
    sum(abs2, u[Iu[1], 1]) / 2 + sum(abs2, u[Iu[2], 2]) / 2
end

u = random_field(setup)

using Zygote
g, = Zygote.gradient(final_energy, u)

@show size(u) size(g)
(102, 102, 2)

Now g is the gradient of final_energy with respect to the initial conditions u, and consequently has the same size.

Note that every operation in the final_energy function is non-mutating and thus differentiable.
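The same pattern can be tried on a problem small enough to check by hand. The following is a sketch only: toy_step and toy_final_energy are hypothetical stand-ins for timestep and final_energy, using explicit Euler for du/dt = -u, for which the gradient of the final energy is known in closed form.

```julia
using Zygote

# Hypothetical stand-in for a solver step (assumption: explicit Euler
# applied to du/dt = -u). It returns a new array and mutates nothing.
toy_step(u, Δt) = u .- Δt .* u

function toy_final_energy(u; Δt = 0.1, nstep = 20)
    for _ in 1:nstep
        u = toy_step(u, Δt)   # rebinds the name `u`, no mutation
    end
    sum(abs2, u) / 2
end

u0 = [1.0, -2.0, 0.5]
g, = Zygote.gradient(toy_final_energy, u0)

# After n steps u_n = (1 - Δt)^n * u0, so the energy is
# (1 - Δt)^(2n) * |u0|^2 / 2 and its gradient is (1 - Δt)^(2n) * u0.
g_exact = (1 - 0.1)^(2 * 20) .* u0
```

Comparing g with g_exact (or with a finite-difference estimate when no closed form exists) is a cheap sanity check before differentiating a full simulation.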

Automatic differentiation with Enzyme

Enzyme.jl is highly efficient, and its ability to perform AD on optimized code allows it to meet or exceed the performance of state-of-the-art AD tools. The downside is that it restricts the user-defined function f: for example, f must not require garbage collection or call BLAS/LAPACK. However, mutation is supported, so an in-place f written as fully mutating, non-allocating code will work with Enzyme and gives the most efficient adjoint implementation.

Enzyme limitation: vector returns

Enzyme's autodiff function can only handle functions with scalar output. To implement pullbacks for array-valued functions, use a mutating function that returns nothing and stores its result in one of the arguments, which must be passed wrapped in a Duplicated. In IncompressibleNavierStokes, we provide enzyme_wrapper to automatically wrap the function and its arguments in the correct way.
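As a minimal sketch of this calling convention, consider a hypothetical in-place function square! (not part of IncompressibleNavierStokes) that writes its result into its first argument and returns nothing. The example calls Enzyme.autodiff directly instead of going through enzyme_wrapper.

```julia
using Enzyme

# Hypothetical in-place function (assumption, for illustration only):
# stores y[i] = x[i]^2 and returns nothing.
function square!(y, x)
    for i in eachindex(y, x)
        y[i] = x[i]^2
    end
    return nothing
end

x = [1.0, 2.0, 3.0]
y = zeros(3)

dy = ones(3)    # seed: adjoint of the output y
dx = zeros(3)   # shadow of x, receives the accumulated adjoint

# Both arrays are wrapped in Duplicated together with their shadows.
Enzyme.autodiff(Enzyme.Reverse, square!, Duplicated(y, dy), Duplicated(x, dx))

# The Jacobian of square! is diagonal with entries 2x, so with a seed of
# ones the adjoint written into dx is 2 .* x.
```

After the call, y holds the function value and dx holds the pullback of the seed. Note that dx is accumulated into, so it should start from zero.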

Example: Gradient of the right-hand side

In this example we differentiate the right-hand side of the Navier-Stokes equations with respect to the velocity field u:

julia
using Enzyme
ax = range(0, 1, 101)
setup = Setup(; x = (ax, ax), Re = 500.0)
psolver = default_psolver(setup)
u = random_field(setup)
dudt = similar(u)
t = 0.0
f! = right_hand_side!
right_hand_side! (generic function with 1 method)

Notice that we are using the mutating (in-place) version of the right-hand side function. This function cannot be differentiated by Zygote, which requires the slower non-mutating version of the right-hand side.

We then define the Dual part of the input and output, required to store the adjoint values:

julia
ddudt = Enzyme.make_zero(dudt) .+ 1;
du = Enzyme.make_zero(u);
102×102×2 Array{Float64, 3} (all entries are 0.0; full printout omitted)

Remember that the derivative of the output (also called the seed) has to be set to 1 in order to compute the gradient. In this case the output is the force, which right_hand_side! stores by mutating dudt in place.
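The role of the seed can be written out explicitly. The notation here is an assumption introduced for illustration: f(u) denotes the right-hand side written into dudt, \bar{y} denotes the seed stored in ddudt, and \bar{u} denotes the adjoint accumulated in du. Reverse mode computes a vector-Jacobian product:

```latex
\bar{u} = \left( \frac{\partial f}{\partial u} \right)^{\top} \bar{y},
\qquad
\bar{y} = \mathbf{1}
\;\Longrightarrow\;
\bar{u} = \nabla_u \sum_i f_i(u)
```

With a seed of ones, the result is therefore the gradient of the sum of all entries of the right-hand side with respect to u.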

Then we pack the parameters to be passed to right_hand_side!:

julia
params = [setup, psolver];
params_ref = Ref(params);
Base.RefValue{Vector{Any}}(Any[…]) (full printout of setup and psolver omitted)

Now, we call the autodiff function from Enzyme:

julia
Enzyme.autodiff(Enzyme.Reverse, f!, Duplicated(dudt, ddudt), Duplicated(u, du), Const(params_ref), Const(t))
((nothing, nothing, nothing, nothing),)

Since we passed u wrapped in a Duplicated object, the gradient with respect to u is stored in its shadow du.

Finally, we can also compare its value with the one obtained by Zygote differentiating the out-of-place (non-mutating) version of the right-hand side:

julia
f = create_right_hand_side(setup, psolver)
_, zpull = Zygote.pullback(f, u, nothing, 0.0);
@assert zpull(dudt)[1] == du