Computing Partial Derivatives using JaxImplicitComponent

One barrier to using OpenMDAO is that, to take full advantage of it, users need to write code for the analytic partial derivatives of their Components. To avoid that, users can use the optional third-party JAX library, which can automatically differentiate native Python and NumPy functions.

This notebook shows how to use JAX to perform automatic differentiation (AD) for a simple example component. Based on the relative sizes of a component’s inputs and outputs, JaxImplicitComponent automatically chooses either forward mode (when there are more outputs than inputs) or reverse mode (when there are more inputs than outputs) to compute the partial Jacobian. If a JaxImplicitComponent’s ‘matrix_free’ attribute is set, the component will instead use jax.jvp in forward mode and jax.vjp in reverse mode.
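To make the two modes concrete, here is a standalone sketch in plain JAX. It does not call OpenMDAO and is not how JaxImplicitComponent is implemented internally. The function `residual(A, b, x)` is an illustrative stand-in with the same form as LinSysComp’s compute_primal shown later, with the adder set to 0. jax.jacfwd and jax.jacrev build full Jacobians in forward and reverse mode, while jax.jvp and jax.vjp compute the Jacobian-vector and vector-Jacobian products that a matrix-free component relies on. Since R = Ax - b, the expected answers are known in advance: dR/dx = A and dR/db = -I.

```python
import jax
import jax.numpy as jnp
import numpy as np


def residual(A, b, x):
    # Illustrative residual with the same form as LinSysComp.compute_primal (adder = 0)
    return A.dot(x) - b


A = jnp.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = jnp.array([1., 2., -3.])
x = jnp.array([0.1, 0.1, 0.1])

# Full partial Jacobians: forward mode (jacfwd) w.r.t. x, reverse mode (jacrev) w.r.t. b
dR_dx_fwd = jax.jacfwd(residual, argnums=2)(A, b, x)   # should equal A
dR_db_rev = jax.jacrev(residual, argnums=1)(A, b, x)   # should equal -I


def residual_of_x(x_):
    # Hold A and b fixed so the products below are taken w.r.t. x only
    return residual(A, b, x_)


v = jnp.array([1., 0., 0.])

# Forward mode, matrix-free: J @ v (here, the first column of A)
_, Jv = jax.jvp(residual_of_x, (x,), (v,))

# Reverse mode, matrix-free: v @ J (here, the first row of A)
_, vjp_fn = jax.vjp(residual_of_x, x)
(vJ,) = vjp_fn(v)

print(np.asarray(Jv), np.asarray(vJ))
```

Each jvp call yields one Jacobian-vector product and each vjp call one vector-Jacobian product, which is why forward mode is cheaper when there are fewer inputs than outputs and reverse mode when there are fewer outputs than inputs.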

By default, JaxImplicitComponent also uses JAX’s just-in-time (jit) compilation, which can substantially speed up computations. Jit compilation can be disabled by setting self.options['use_jit'] to False.
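A minimal illustration of jit in plain JAX, independent of how JaxImplicitComponent applies it internally: jax.jit wraps a function, traces and compiles it on the first call, and reuses the compiled version on later calls with the same argument shapes and dtypes. The result matches the un-jitted function; the actual speedup depends on the problem size and hardware, so no timings are claimed here. The residual A x - b below, using the A and b from later in this notebook and evaluated at the solution x = [-4, 9, -4], is an illustrative choice.

```python
import jax
import jax.numpy as jnp
import numpy as np

A = jnp.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = jnp.array([1., 2., -3.])


def residual(x):
    # Illustrative residual R(x) = A x - b (A and b are captured as constants)
    return A.dot(x) - b


# jax.jit returns a compiled wrapper; compilation happens on the first call
residual_jit = jax.jit(residual)

x = jnp.array([-4., 9., -4.])
r_eager = residual(x)      # evaluated operation by operation
r_jit = residual_jit(x)    # traced, compiled, then executed

print(np.asarray(r_eager), np.asarray(r_jit))
```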

JAX is an optional dependency of OpenMDAO, so if it is not already installed, you need to install it. See the installation instructions for more information about installing JAX.

!pip install jax

Here are some standard OpenMDAO imports.

import numpy as np
import openmdao.api as om

The JAX library includes a NumPy-like API, jax.numpy, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with jax.numpy. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs.

To use jax.numpy, import it under the conventional jnp alias.

import jax # noqa
import jax.numpy as jnp  # noqa
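A quick sketch of the similarity, and of one practical difference that matters when writing compute_primal: jax.numpy mirrors NumPy’s function names, but JAX arrays are immutable. In-place assignment such as `y[0] = 1.0` therefore raises an error, and the functional form `y.at[0].set(1.0)` is used instead; it returns a new array and leaves the original unchanged. The array values are arbitrary.

```python
import numpy as np
import jax.numpy as jnp

A_np = np.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
A_jnp = jnp.array(A_np)

# Same function names and semantics for common operations
row_sums_np = np.sum(A_np, axis=1)
row_sums_jnp = jnp.sum(A_jnp, axis=1)

# NumPy arrays can be modified in place ...
y_np = np.zeros(3)
y_np[0] = 1.0

# ... but JAX arrays cannot
y_jnp = jnp.zeros(3)
setitem_error = None
try:
    y_jnp[0] = 1.0
except TypeError as err:
    setitem_error = type(err).__name__

# The functional update returns an updated copy
y_jnp_new = y_jnp.at[0].set(1.0)

print(row_sums_np, np.asarray(row_sums_jnp), setitem_error, np.asarray(y_jnp_new))
```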

Below is the JaxImplicitComponent in the model whose derivatives will be computed using JAX; comments interspersed in the code provide explanations and guidance. First, here is an overview of the steps needed to make use of JAX in your own JaxImplicitComponent.

  1. Inherit your component from JaxImplicitComponent.

  2. Write a method named compute_primal that computes the residuals from the inputs and outputs. This method does the same work you would normally put in the apply_nonlinear method of an ImplicitComponent, but it takes the individual input and output variables as arguments rather than dictionaries of inputs and outputs, and it returns the residuals as a tuple. This is what allows JAX’s AD capabilities to be applied to the method. Ordering is critical: the inputs and outputs passed into the method, and the residuals returned from it, must be in the same order in which the variables are declared in the component. Discrete inputs, if any, are passed individually as arguments after the continuous variables. (LinSysComp below has a single output and returns its residual array directly.)

  3. By default, your component will jit the compute_primal method. If you don’t want this, set self.options['use_jit'] to False.

  4. For a typical component, that’s it. You can skip step 5.

  5. However, if your compute_primal depends on variables that are ‘static’, i.e., values that don’t change during the computation of derivatives, you’ll need to define a get_self_statics method on your component. It must return a tuple containing all of the static values that compute_primal depends on, excluding any discrete input variables, and the returned tuple must be hashable. If these static values ever change, JAX will recompile the compute_primal function. In LinSysComp below, there is one static option variable, self.options['adder'].
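The effect of a static value can be sketched with plain jax.jit and its static_argnums argument. This is only an analogy for what get_self_statics enables, not OpenMDAO’s internal implementation: a static argument is baked into the compiled code, it must be hashable (a tuple such as (0.0,) is; a list or NumPy array is not), and each new value triggers a fresh trace and compile. The Python list below is appended to only while JAX traces the function, so it records each compilation. The function name and data are illustrative.

```python
import jax
import jax.numpy as jnp
import numpy as np

A = jnp.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = jnp.array([1., 2., -3.])

trace_log = []


def residual(A, b, x, adder):
    # Python side effect: runs while JAX traces the function, not on cached calls
    trace_log.append(adder)
    return A.dot(x + adder) - b


# Argument 3 (adder) is static, playing the role of self.options['adder']
residual_jit = jax.jit(residual, static_argnums=(3,))

x0 = jnp.array([-4., 9., -4.])  # solution when adder = 0.0
x1 = jnp.array([-5., 8., -5.])  # solution when adder = 1.0

r0 = residual_jit(A, b, x0, 0.0)    # trace + compile for adder = 0.0
r0b = residual_jit(A, b, x0, 0.0)   # same static value: cached, no new trace
r1 = residual_jit(A, b, x1, 1.0)    # new static value: trace + compile again

print(trace_log)
```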

class LinSysComp(om.JaxImplicitComponent):
    def initialize(self):
        self.options.declare('size', default=1, types=int)

        # adder is a 'static' value that is constant during derivative computation, but it
        # can be changed by the user between runs, so any jitted jax functions will be re-jitted
        # if it changes
        self.options.declare('adder', default=0., types=float)

    def setup(self):
        size = self.options['size']

        shape = (size, )

        self.add_input("A", val=np.eye(size))
        self.add_input("b", val=np.ones(shape))
        self.add_output("x", shape=shape, val=.1)

    def setup_partials(self):
        size = self.options['size']
        mat_size = size * size

        # dR_i/db_i = -1: a diagonal pattern with a constant value
        row_col = np.arange(size, dtype=int)
        self.declare_partials('x', 'b', val=np.full(size, -1.0), rows=row_col, cols=row_col)

        # dR_i/dA_ij = x_j + adder: residual i depends only on row i of the
        # row-major flattened A, i.e. entries i*size through i*size + size - 1
        rows = np.repeat(np.arange(size), size)
        cols = np.arange(mat_size)
        self.declare_partials('x', 'A', rows=rows, cols=cols)

        # dR_i/dx_j = A_ij: a dense size x size pattern, listed row by row
        cols = np.tile(np.arange(size), size)
        self.declare_partials(of='x', wrt='x', rows=rows, cols=cols)

        # if we're running in matrix_free mode, we don't want to use a direct solver
        if self.matrix_free:
            self.linear_solver = om.ScipyKrylov()
        else:
            self.linear_solver = om.DirectSolver()
        self.nonlinear_solver = om.NewtonSolver(solve_subsystems=False)

    # we have a static variable, self.options['adder'] that we use in compute_primal, so we need to
    # define the get_self_statics method to tell OpenMDAO about it
    def get_self_statics(self):
        return (self.options['adder'], )

    # compute_primal replaces the old apply_nonlinear method
    def compute_primal(self, A, b, x):
        return A.dot(x + self.options['adder']) - b
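The rows/cols patterns declared in setup_partials can be checked without OpenMDAO. The sketch below computes central-difference Jacobians of R = A(x + adder) - b with respect to the row-major flattened A and to x, then confirms that every nonzero entry lies at a declared (row, col) pair and that the entries have the expected values (x_j + adder for A, A_ij for x). The test point, the adder value, and the helper names `fd_jacobian` and `pattern_covers` are illustrative assumptions.

```python
import numpy as np

size = 3
A = np.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = np.array([1., 2., -3.])
x = np.array([0.5, -2.0, 3.0])   # arbitrary test point
adder = 0.25                     # arbitrary static value


def residual(A_flat, x):
    # Same residual as LinSysComp.compute_primal, with A passed flattened (row-major)
    return A_flat.reshape(size, size).dot(x + adder) - b


def fd_jacobian(f, z, eps=1e-6):
    # Central-difference Jacobian of f with respect to the 1-D array z
    columns = []
    for k in range(z.size):
        e = np.zeros_like(z)
        e[k] = eps
        columns.append((f(z + e) - f(z - e)) / (2 * eps))
    return np.column_stack(columns)


J_A = fd_jacobian(lambda a: residual(a, x), A.ravel())   # shape (3, 9)
J_x = fd_jacobian(lambda v: residual(A.ravel(), v), x)   # shape (3, 3)

# Patterns as built in LinSysComp.setup_partials
rows = np.repeat(np.arange(size), size)
cols_A = np.arange(size * size)
cols_x = np.tile(np.arange(size), size)


def pattern_covers(dense, rows, cols, tol=1e-8):
    # True if every nonzero entry of `dense` sits at a declared (row, col) pair
    mask = np.zeros(dense.shape, dtype=bool)
    mask[rows, cols] = True
    return bool(np.all(mask | (np.abs(dense) < tol)))


print(pattern_covers(J_A, rows, cols_A), pattern_covers(J_x, rows, cols_x))
```

For size n, the A pattern declares n² entries out of the n³ entries in the dense dR/dA block, which is the benefit of declaring it sparse.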

The rest of this code is standard OpenMDAO code and can be run as usual.

A = np.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = np.array([1, 2, -3])

prob = om.Problem()

ivc = prob.model.add_subsystem('ivc', om.IndepVarComp())
ivc.add_output('A', A)
ivc.add_output('b', b)

lingrp = prob.model.add_subsystem('lingrp', om.Group())
lin = lingrp.add_subsystem('lin', LinSysComp(size=b.size))

prob.model.connect('ivc.A', 'lingrp.lin.A')
prob.model.connect('ivc.b', 'lingrp.lin.b')

prob.setup()
prob.set_solver_print(level=0)

prob.set_val('ivc.A', A)
prob.set_val('ivc.b', b)

prob.run_model()

print(prob['lingrp.lin.x'])

# changing the adder value should change the resulting x value
lin.options['adder'] = 1.0

prob.run_model()

print(prob['lingrp.lin.x'])
[-4.  9. -4.]
[-5.  8. -5.]
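Because the residual is linear in x, the printed values can be cross-checked with NumPy alone: A(x + adder) - b = 0 gives x = A⁻¹b - adder, so changing adder from 0 to 1 shifts every component of x down by exactly 1, matching the two printed results. The sketch also takes one Newton step from the initial guess x = 0.1 declared in setup; since dR/dx = A is constant, that single step lands on the solution in exact arithmetic. This illustrates the math only, not how many iterations OpenMDAO’s NewtonSolver actually reports.

```python
import numpy as np

A = np.array([[1., 1., 1.], [1., 2., 3.], [0., 1., 3.]])
b = np.array([1., 2., -3.])

# Direct solve of A (x + adder) = b for each adder value used above
results = {}
for adder in (0.0, 1.0):
    results[adder] = np.linalg.solve(A, b) - adder
    print(adder, results[adder])

# One Newton step from x0 = 0.1 with adder = 0: x_new = x0 - J^{-1} R(x0), where J = dR/dx = A
x0 = np.full(3, 0.1)
r0 = A.dot(x0) - b
x_newton = x0 - np.linalg.solve(A, r0)
print(x_newton)
```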