Differentiating through the code
IncompressibleNavierStokes is reverse-mode differentiable, which means that you can back-propagate gradients through the code. Two AD libraries are currently supported:
Zygote.jl: it is the default AD library in the Julia ecosystem and is the most widely used.
Enzyme.jl: currently supports only a subset of the Julia language, but it is usually the most efficient option where it applies.
Automatic differentiation with Zygote
Zygote.jl is the default choice of AD backend because it is easy to understand, compatible with most of the Julia ecosystem, and works well with vectorized code and BLAS. This comes at a cost, however: intermediate velocity fields need to be stored in memory for use in the backward pass. For this reason, many of the operators come in two versions: a slow, differentiable, allocating, non-mutating variant (e.g. divergence) and a fast, non-differentiable, non-allocating, mutating variant (e.g. divergence!).
Zygote limitation: array mutation
To make your code differentiable, you must use the differentiable versions of the operators (without the exclamation marks).
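The difference between the two kinds of operators can be sketched on a toy problem. The functions double and double! below are hypothetical stand-ins written for this illustration; they are not part of IncompressibleNavierStokes. The sketch assumes only that Zygote is installed.

```julia
using Zygote

# Hypothetical toy operators (not part of IncompressibleNavierStokes)
double(u) = 2 .* u              # non-mutating: allocates a new array

function double!(out, u)        # mutating: writes the result into `out`
    for i in eachindex(u)
        out[i] = 2 * u[i]
    end
    return out
end

loss(u) = sum(abs2, double(u))                 # differentiable with Zygote
loss!(u) = sum(abs2, double!(similar(u), u))   # uses array mutation

u = [1.0, 2.0, 3.0]

# Works: the gradient of sum((2u)^2) is 8u
g, = Zygote.gradient(loss, u)

# The mutating variant is expected to throw, since Zygote
# does not support array mutation
mutation_failed = try
    Zygote.gradient(loss!, u)
    false
catch
    true
end
```

Both loss and loss! return the same value in the forward pass; only the non-mutating one can be back-propagated through by Zygote.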
Example: Gradient of kinetic energy
To differentiate outputs of a simulation with respect to the initial conditions, make a time stepping loop composed of differentiable operations:
using IncompressibleNavierStokes
ax = range(0, 1, 101)
setup = Setup(; x = (ax, ax), Re = 500.0)
psolver = default_psolver(setup)
method = RKMethods.RK44P2()
Δt = 0.01
nstep = 100
(; Iu) = setup.grid
function final_energy(u)
stepper = create_stepper(method; setup, psolver, u, temp = nothing, t = 0.0)
for it = 1:nstep
stepper = timestep(method, stepper, Δt)
end
(; u) = stepper
sum(abs2, u[Iu[1], 1]) / 2 + sum(abs2, u[Iu[2], 2]) / 2
end
u = random_field(setup)
using Zygote
g, = Zygote.gradient(final_energy, u)
@show size(u) size(g)
(102, 102, 2)
Now g is the gradient of final_energy with respect to the initial conditions u, and consequently has the same size.
Note that every operation in the final_energy function is non-mutating and thus differentiable.
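A gradient obtained this way can be sanity-checked against a finite difference. The sketch below does this on a deliberately small stand-in problem so that it runs quickly: euler_step and final_energy_toy are hypothetical names, and the "solver" is explicit Euler for du/dt = -u rather than the Navier-Stokes time stepper.

```julia
using Zygote

# Hypothetical toy time stepper: explicit Euler for du/dt = -u
euler_step(u, Δt) = u .- Δt .* u

function final_energy_toy(u; Δt = 0.1, nstep = 10)
    for _ in 1:nstep
        u = euler_step(u, Δt)   # rebinds `u`; no array is mutated
    end
    sum(abs2, u) / 2
end

u0 = [1.0, -2.0, 0.5]
g, = Zygote.gradient(final_energy_toy, u0)

# Central finite difference of the energy along a direction v
v = [0.3, 0.1, -0.2]
ϵ = 1e-6
fd = (final_energy_toy(u0 .+ ϵ .* v) - final_energy_toy(u0 .- ϵ .* v)) / (2ϵ)

# Directional derivative predicted by the gradient
ad = sum(g .* v)
```

If the gradient is correct, ad and fd agree up to round-off. The same check could in principle be applied to final_energy above, at the cost of two extra simulations per direction.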
Automatic differentiation with Enzyme
Enzyme.jl is highly efficient, and its ability to perform AD on optimized code allows it to meet or exceed the performance of state-of-the-art AD tools. The downside is that it restricts the user-defined function f: it must not, for example, require garbage collection or call BLAS/LAPACK. However, mutation is supported, so an in-place f with fully mutating, non-allocating code will work with Enzyme, and this will be the most efficient adjoint implementation.
Enzyme limitation: vector returns
Enzyme's autodiff function can only handle functions with scalar output. To implement pullbacks for array-valued functions, use a mutating function that returns nothing and stores its result in one of the arguments, which must be passed wrapped in a Duplicated. In IncompressibleNavierStokes, we provide enzyme_wrapper to automatically wrap the function and its arguments in the correct way.
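The pattern described above can be sketched on a toy function. Here square! is a hypothetical example function, not part of IncompressibleNavierStokes, and the call uses Enzyme.autodiff directly rather than enzyme_wrapper. The sketch assumes only that Enzyme is installed.

```julia
using Enzyme

# Hypothetical array-valued function in mutating form:
# the result is stored in `y` and the function returns nothing
function square!(y, x)
    for i in eachindex(x)
        y[i] = x[i]^2
    end
    return nothing
end

x = [1.0, 2.0, 3.0]
y = zero(x)

dx = zero(x)    # shadow of the input: receives the adjoint
dy = one.(y)    # shadow of the output: the seed, set to one

# Both arrays are wrapped in Duplicated so that Enzyme can
# read the seed from `dy` and accumulate the result into `dx`
Enzyme.autodiff(Enzyme.Reverse, square!, Duplicated(y, dy), Duplicated(x, dx))
```

After the call, y holds the squared values and dx holds the transposed Jacobian applied to the seed, which for this function is 2 .* x.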
Example: Gradient of the right-hand side
In this example we differentiate the right-hand side of the Navier-Stokes equations with respect to the velocity field u:
using Enzyme
ax = range(0, 1, 101)
setup = Setup(; x = (ax, ax), Re = 500.0)
psolver = default_psolver(setup)
u = random_field(setup)
dudt = similar(u)
t = 0.0
f! = right_hand_side!
right_hand_side! (generic function with 1 method)
Notice that we are using the mutating (in-place) version of the right-hand side function. This function cannot be differentiated by Zygote, which requires the slower non-mutating version of the right-hand side.
We then define the Dual part of the input and output, required to store the adjoint values:
ddudt = Enzyme.make_zero(dudt) .+ 1;
du = Enzyme.make_zero(u);
102×102×2 Array{Float64, 3}: (output truncated; all entries are 0.0)
Remember that the derivative of the output (also called the seed) has to be set to one, which is why ddudt is filled with ones above. Here the output is stored in dudt, which is mutated inside right_hand_side!.
Then we pack the parameters to be passed to right_hand_side!:
params = [setup, psolver];
params_ref = Ref(params);
Base.RefValue{Vector{Any}}(Any[…]) (output truncated; the vector holds setup and psolver)
Now, we call the autodiff function from Enzyme:
Enzyme.autodiff(Enzyme.Reverse, f!, Duplicated(dudt,ddudt), Duplicated(u,du), Const(params_ref), Const(t))
((nothing, nothing, nothing, nothing),)
Since we have passed a Duplicated object, the gradient with respect to u is stored in du.
Finally, we can compare this value with the one obtained by Zygote, which differentiates the out-of-place (non-mutating) version of the right-hand side:
f = create_right_hand_side(setup, psolver)
_, zpull = Zygote.pullback(f, u, nothing, 0.0);
@assert zpull(dudt)[1] == du