Computing Partial Derivatives using JaxImplicitComponent#

This notebook gives an example of using jax to do automatic differentiation (AD) for a linear system component. The example contains a JaxImplicitComponent called LinSysComp. A static option called ‘adder’ has been added to LinSysComp in order to demonstrate how to handle what we call ‘self statics’ in a jax component. Comments interspersed in the code provide some explanations and guidance.

Here is an overview of the steps that need to be taken to make use of jax for your JaxImplicitComponent.

  1. Inherit your component from JaxImplicitComponent.

  2. Write a method named compute_primal that computes the residuals from the inputs and outputs. This method plays the same role as the apply_nonlinear method of a standard ImplicitComponent, but it takes the actual individual input and output variables as arguments rather than dictionaries of inputs and outputs, and it returns the residuals as a tuple. This allows us to apply JAX’s AD capabilities to the method. Argument ordering is critical: the variables must be passed in the order they are declared in the component, with all of the inputs first, followed by the outputs, and the returned residuals must match the output declaration order. Discrete inputs, if any, are passed individually as arguments after the output variables.

  3. By default your component will jit the compute_primal method. If for some reason you don’t want this, then you can set self.options['use_jit'] to False. This can be useful when debugging as it allows you to put print statements inside of your compute_primal method.

  4. For a typical component, that’s it. You can skip step 5.

  5. If your compute_primal depends on variables that are ‘static’ according to jax, i.e., they affect the output of your compute_primal but are not passed in as arguments, you’ll need to define a get_self_statics method on your component that returns a tuple containing all such variables. The returned tuple must be hashable. If these static values ever change, jax will recompile the compute_primal function. In LinSysComp below, there is one static option variable, self.options['adder'].
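As a sketch of the calling conventions in steps 2 and 5, consider a hypothetical component with two inputs a and b, two outputs x and y (declared in that order), and one static value. The names and residual expressions here are purely illustrative; the real LinSysComp example follows below:

```python
class DemoPrimal:
    # Hypothetical stand-in illustrating compute_primal's argument ordering:
    # all inputs first (in declaration order), then all outputs.
    def __init__(self, scaler=1.0):
        self.scaler = scaler  # a 'static' value not passed as an argument

    def compute_primal(self, a, b, x, y):
        r_x = self.scaler * a * x - b  # residual for output x
        r_y = x + y - a                # residual for output y
        return (r_x, r_y)              # residuals, in output declaration order

    def get_self_statics(self):
        # must be hashable; the trailing comma makes this a one-element tuple
        return (self.scaler,)
```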

Linear System Component Example#

The following component is a JaxImplicitComponent representing a linear system A(x+adder) = b.

import openmdao.api as om
import numpy as np
import jax.numpy as jnp  # noqa


class LinSysComp(om.JaxImplicitComponent):
    def initialize(self):
        self.options.declare('size', default=1, types=int)

        # adder is a 'static' value that is constant during derivative computation, but it
        # can be changed by the user between runs, so any jitted jax functions need to be re-jitted
        # if it changes
        self.options.declare('adder', default=0.0, types=float)

    def setup(self):
        size = self.options['size']

        shape = (size, )

        self.add_input("A", val=np.eye(size))
        self.add_input("b", val=np.ones(shape))
        self.add_output("x", shape=shape, val=.1)

    def setup_partials(self):
        # This implicit component relies on solvers to converge its output and to
        # compute derivatives, so we assign a linear and a nonlinear solver here
        if self.matrix_free:
            # if we're running in matrix_free mode, don't use a direct solver
            self.linear_solver = om.ScipyKrylov()
        else:
            self.linear_solver = om.DirectSolver()
        self.nonlinear_solver = om.NewtonSolver(solve_subsystems=False)

    # we have a static variable, self.options['adder'] that we use in compute_primal, so we need to
    # define the get_self_statics method to tell OpenMDAO about it.  Also note that the return value
    # must be a tuple, so we follow the value with a comma since we have only one value.
    def get_self_statics(self):
        return (self.options['adder'], )

    # compute_primal replaces the apply_nonlinear method
    def compute_primal(self, A, b, x):
        return A.dot(x + self.options['adder']) - b
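
As a quick standalone sanity check (separate from the component, using illustrative values), jax can differentiate a residual of this form directly. For the linear residual A(x + adder) - b, the Jacobian with respect to x is simply A:

```python
import jax
import jax.numpy as jnp

def residual(A, b, x, adder):
    # same residual form as LinSysComp.compute_primal
    return A.dot(x + adder) - b

A = jnp.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = jnp.array([1., 2., -3.])
x = jnp.zeros(3)

# forward-mode Jacobian of the residual with respect to x (argnums=2)
J = jax.jacfwd(residual, argnums=2)(A, b, x, 0.0)
```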

The rest of this code is standard OpenMDAO code and can be run as normal.

A = np.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = np.array([1, 2, -3])

prob = om.Problem()

ivc = prob.model.add_subsystem('ivc', om.IndepVarComp())
ivc.add_output('A', A)
ivc.add_output('b', b)

lingrp = prob.model.add_subsystem('lingrp', om.Group())
lin = lingrp.add_subsystem('lin', LinSysComp(size=b.size))

prob.model.connect('ivc.A', 'lingrp.lin.A')
prob.model.connect('ivc.b', 'lingrp.lin.b')

prob.setup()
prob.set_solver_print(level=0)

prob.set_val('ivc.A', A)
prob.set_val('ivc.b', b)

prob.run_model()

print(prob['lingrp.lin.x'])

# changing the adder value should change the resulting x value
lin.options['adder'] = 1.0

prob.run_model()

print(prob['lingrp.lin.x'])
[-4.  9. -4.]
[-5.  8. -5.]