Computing Partial Derivatives using JaxImplicitComponent
This notebook gives an example of using jax to do automatic differentiation (AD) for a linear system component. The example contains a JaxImplicitComponent called LinSysComp. A static option called ‘adder’ has been added to LinSysComp in order to demonstrate how to handle what we call ‘self statics’ in a jax component. Comments interspersed in the code provide some explanations and guidance.

Here is an overview of the steps that need to be taken to make use of jax for your JaxImplicitComponent.
1. Inherit your component from JaxImplicitComponent.
2. Write a method named compute_primal to compute the residuals from the inputs and outputs. This method is the same as what you would normally write for the apply_nonlinear method of an ImplicitComponent, but it takes the actual individual input and output variables as its arguments rather than a dictionary of the inputs and outputs, and it returns the residuals as a tuple. This allows us to use JAX’s AD capabilities on this method. Ordering of the inputs and outputs is critical: the order of the inputs and outputs passed into the method, and of the residuals returned from it, must match the order in which they are declared in the component. Note that all of the inputs are passed first, followed by the outputs. Also, discrete inputs, if any, are passed individually as arguments after the output variables.
3. By default your component will jit the compute_primal method. If for some reason you don’t want this, you can set self.options['use_jit'] to False. This can be useful when debugging, because it allows you to put print statements inside your compute_primal method.
4. For a typical component, that’s it. You can skip step 5.
5. If your compute_primal depends on variables that are ‘static’ according to jax, i.e., they affect the output of your compute_primal but are not passed in as arguments, you’ll need to define a get_self_statics method on your component that returns a tuple containing all such variables. The returned tuple must be hashable. If these static values ever change, jax will recompile the compute_primal function. In LinSysComp below, there is one static option variable, self.options['adder'].
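The hashability requirement in step 5 can be illustrated with a small, stdlib-only sketch. Everything in it is hypothetical: StaticsCache and build_residual are illustrative names, and this is not OpenMDAO's or jax's actual caching machinery. The sketch only shows why a statics tuple works well as a cache key. An unchanged tuple reuses the previously built function, and a changed value forces a fresh "compile".

```python
# Hypothetical illustration only: NOT OpenMDAO's or jax's real implementation.
# A statics tuple is used as a dictionary key, so it must be hashable, and a
# new value triggers a rebuild ("recompile") of the residual function.

class StaticsCache:
    """Caches one built function per distinct statics tuple."""

    def __init__(self, build):
        self._build = build      # callable: statics tuple -> built function
        self._cache = {}         # statics tuple -> built function
        self.compile_count = 0   # number of times build() has been called

    def get(self, statics):
        # The membership test hashes the key; an unhashable list raises TypeError.
        if statics not in self._cache:
            self._cache[statics] = self._build(statics)
            self.compile_count += 1
        return self._cache[statics]


def build_residual(statics):
    # Bake the static 'adder' into the returned function, the way a jitted
    # function treats a value that is not passed in as an argument.
    (adder,) = statics

    def residual(A, b, x):
        # R = A (x + adder) - b, using plain Python lists
        return [sum(a_ij * (x_j + adder) for a_ij, x_j in zip(row, x)) - b_i
                for row, b_i in zip(A, b)]

    return residual


cache = StaticsCache(build_residual)
A = [[1.0, 1.0, 1.0], [1.0, 2.0, 3.0], [0.0, 1.0, 3.0]]
b = [1.0, 2.0, -3.0]

r0 = cache.get((0.0,))(A, b, [-4.0, 9.0, -4.0])        # first build
r0_again = cache.get((0.0,))(A, b, [-4.0, 9.0, -4.0])  # same statics: cached
r1 = cache.get((1.0,))(A, b, [-5.0, 8.0, -5.0])        # new statics: rebuilt
print(r0, r1, cache.compile_count)
```

Passing a list such as [0.0] instead of a tuple fails at the dictionary lookup with a TypeError. That failure mirrors the requirement that get_self_statics return a hashable tuple.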
Linear System Component Example
The following component is a JaxImplicitComponent representing the linear system A(x + adder) = b.
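For reference, the partial derivatives that jax computes automatically for this residual can also be written out by hand. Here a denotes the scalar adder, which is broadcast across every entry of x, and δ is the Kronecker delta.

```latex
\begin{aligned}
R(A, b, x) &= A\,(x + a\mathbf{1}) - b, \\
\frac{\partial R_i}{\partial x_j} &= A_{ij}, \qquad
\frac{\partial R_i}{\partial b_j} = -\delta_{ij}, \qquad
\frac{\partial R_i}{\partial A_{jk}} = \delta_{ij}\,(x_k + a), \\
R = 0 \;&\Longrightarrow\; x = A^{-1} b - a\mathbf{1}.
\end{aligned}
```

Because R is linear in x with constant Jacobian A, a Newton solve with an exact linear solve converges in a single iteration for this component.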
import openmdao.api as om
import numpy as np
import jax.numpy as jnp  # noqa


class LinSysComp(om.JaxImplicitComponent):
    def initialize(self):
        self.options.declare('size', default=1, types=int)
        # adder is a 'static' value that is constant during derivative computation, but it
        # can be changed by the user between runs, so any jitted jax functions need to be re-jitted
        # if it changes
        self.options.declare('adder', default=0.0, types=float)

    def setup(self):
        size = self.options['size']
        shape = (size, )
        self.add_input("A", val=np.eye(size))
        self.add_input("b", val=np.ones(shape))
        self.add_output("x", shape=shape, val=.1)

    def setup_partials(self):
        # Because this is an ImplicitComponent, we have to define both a linear and a nonlinear solver
        if self.matrix_free:
            # if we're running in matrix_free mode, don't use a direct solver
            self.linear_solver = om.ScipyKrylov()
        else:
            self.linear_solver = om.DirectSolver()
        self.nonlinear_solver = om.NewtonSolver(solve_subsystems=False)

    # we have a static variable, self.options['adder'], that we use in compute_primal, so we need to
    # define the get_self_statics method to tell OpenMDAO about it. Also note that the return value
    # must be a tuple, so we follow the value with a comma since we have only one value.
    def get_self_statics(self):
        return (self.options['adder'], )

    # compute_primal replaces the apply_nonlinear method
    def compute_primal(self, A, b, x):
        return A.dot(x + self.options['adder']) - b
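To see what jax's AD produces for this residual outside of OpenMDAO, the sketch below applies jax.jacfwd to a standalone copy of compute_primal and compares the result against hand-derived partials. The function name residual and its explicit adder argument are illustrative assumptions, not OpenMDAO API, and the way OpenMDAO wires jax into its derivative machinery internally may differ. The sketch only checks the math.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical standalone copy of LinSysComp.compute_primal; the static
# 'adder' is passed explicitly instead of being read from self.options.
def residual(A, b, x, adder=0.0):
    return A.dot(x + adder) - b


A = jnp.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = jnp.array([1., 2., -3.])
x = jnp.array([0.1, 0.2, 0.3])
adder = 1.0

# Forward-mode Jacobians of R with respect to A, b and x (argnums 0, 1, 2).
# dR_dA has shape (3, 3, 3): dR_dA[i, j, k] = dR_i / dA_jk.
dR_dA, dR_db, dR_dx = jax.jacfwd(residual, argnums=(0, 1, 2))(A, b, x, adder)

# Hand-derived partials for comparison
expected_dR_dx = A                                               # dR/dx = A
expected_dR_db = -jnp.eye(3)                                     # dR/db = -I
expected_dR_dA = jnp.einsum('ij,k->ijk', jnp.eye(3), x + adder)  # delta_ij (x_k + a)

print(np.allclose(dR_dx, expected_dR_dx),
      np.allclose(dR_db, expected_dR_db),
      np.allclose(dR_dA, expected_dR_dA))
```

Note that jax treats the non-differentiated adder argument as a constant here, which matches its role as a static value in LinSysComp.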
The remaining code, which builds and runs the model, is standard OpenMDAO code and runs as usual.
A = np.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = np.array([1, 2, -3])
prob = om.Problem()
ivc = prob.model.add_subsystem('ivc', om.IndepVarComp())
ivc.add_output('A', A)
ivc.add_output('b', b)
lingrp = prob.model.add_subsystem('lingrp', om.Group())
lin = lingrp.add_subsystem('lin', LinSysComp(size=b.size))
prob.model.connect('ivc.A', 'lingrp.lin.A')
prob.model.connect('ivc.b', 'lingrp.lin.b')
prob.setup()
prob.set_solver_print(level=0)
prob.set_val('ivc.A', A)
prob.set_val('ivc.b', b)
prob.run_model()
print(prob['lingrp.lin.x'])
# changing the adder value should change the resulting x value
lin.options['adder'] = 1.0
prob.run_model()
print(prob['lingrp.lin.x'])
[-4. 9. -4.]
[-5. 8. -5.]
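As an independent sanity check on the printed values, the residual A(x + adder) − b vanishes exactly when x = A⁻¹b − adder. NumPy alone can therefore reproduce both results. The helper name solve_with_adder is hypothetical and exists only for this check.

```python
import numpy as np

A = np.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = np.array([1., 2., -3.])


def solve_with_adder(A, b, adder):
    # A (x + adder) - b = 0  =>  x + adder = A^{-1} b  =>  x = A^{-1} b - adder
    return np.linalg.solve(A, b) - adder


x0 = solve_with_adder(A, b, 0.0)  # matches the first printed result
x1 = solve_with_adder(A, b, 1.0)  # matches the result after setting adder = 1.0
print(x0, x1)
```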