Computing Partial Derivatives of Explicit Components Using JAX

One barrier to using OpenMDAO is that, to take full advantage of it, users must write code for the analytic partial derivatives of their Components. To avoid this, users can turn to the optional third-party JAX library, which can automatically differentiate native Python and NumPy functions.

This notebook gives an example of using JAX to perform automatic differentiation (AD) for the Sellar example. Only forward-mode AD (jax.jacfwd) is used here; JAX also provides reverse-mode AD (jax.jacrev).

Forward mode is typically more efficient for “tall” Jacobian matrices (more outputs than inputs), whereas reverse mode is typically more efficient for “wide” Jacobian matrices (more inputs than outputs).
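As a minimal sketch of this trade-off, assuming only that JAX is installed, the hypothetical function `tall` below maps 2 inputs to 3 outputs. Both `jax.jacfwd` and `jax.jacrev` return the same 3 x 2 Jacobian; the mode only changes how the work is organized (forward mode pushes one input direction through per pass, reverse mode pulls back one output per pass).

```python
import jax
import jax.numpy as jnp


def tall(x):
    # Hypothetical R^2 -> R^3 function: more outputs than inputs, so the Jacobian is 3 x 2 ("tall")
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])


x = jnp.array([1.0, 2.0])

# Forward mode: one Jacobian-vector product per input, i.e. one column at a time.
J_fwd = jax.jacfwd(tall)(x)
# Reverse mode: one vector-Jacobian product per output, i.e. one row at a time.
J_rev = jax.jacrev(tall)(x)

print(J_fwd.shape)                       # (3, 2)
print(bool(jnp.allclose(J_fwd, J_rev)))  # True
```

For this tall Jacobian, forward mode needs 2 passes and reverse mode needs 3, which is the intuition behind the rule of thumb above.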

This notebook also shows how to use JAX’s just-in-time (jit) compiling capabilities to dramatically speed up computations.

The use of JAX is optional for OpenMDAO, so if it is not already installed, install it by running one of the following commands at your operating system's command prompt:

pip install jax jaxlib
pip install openmdao[jax]
pip install openmdao[all]

Here are the standard imports used in this example:

from functools import partial

import numpy as np

import openmdao.api as om

The JAX library includes a NumPy-like API, jax.numpy, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with jax.numpy. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs.

To use jax.numpy, import it under the conventional jnp alias.

import jax
import jax.numpy as jnp
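A short sketch of how closely jax.numpy mirrors NumPy, plus the difference most likely to matter when porting compute code: JAX arrays are immutable. The array values are arbitrary and only for illustration.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

# Same function names and semantics as NumPy
print(np.allclose(np.sin(x_np) ** 2, jnp.sin(x_jnp) ** 2))  # True

# JAX arrays are immutable: x_jnp[0] = 10.0 would raise a TypeError.
# The .at[...] syntax returns a modified copy instead.
y = x_jnp.at[0].set(10.0)
print(float(y[0]), float(x_jnp[0]))  # 10.0 0.0
```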

By default, JAX performs single-precision (32-bit) computations. This example needs double precision, so 64-bit mode must be enabled:

jax.config.update("jax_enable_x64", True)
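A minimal check, assuming a fresh Python session: once 64-bit mode is enabled, new floating-point arrays default to float64. The flag is meant to be set at startup, before arrays are created. The second part illustrates why single precision can be a problem for derivative work: a perturbation of 1e-10 is lost entirely in float32.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Floating-point arrays now default to float64
a = jnp.zeros(2)
print(a.dtype)  # float64

# 1 + 1e-10 is distinguishable from 1 in float64 but rounds to exactly 1.0 in float32
print(bool(jnp.float64(1.0) + 1e-10 == 1.0))                   # False
print(bool(jnp.float32(1.0) + jnp.float32(1e-10) == 1.0))      # True
```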

Here is one of the ExplicitComponents in the model whose derivatives are computed using JAX. Comments in the code provide explanations and guidance. The steps needed to use JAX in your own ExplicitComponent are outlined below.

NOTE: A newer, experimental API for using JAX with ExplicitComponents, which simplifies the process outlined below, has been developed. See the OpenMDAO documentation for details.

  1. Write a method to compute the outputs from the inputs. Borrowing from AD terminology, a suggested name for this method is _compute_primal. This method is the same as what you would normally write for the compute method of an ExplicitComponent, but it takes the individual input variables as its arguments rather than a dictionary of the inputs. This allows JAX’s AD capabilities to be applied to the method. The _compute_primal method returns the outputs as a single value, or as a tuple if there is more than one output. Apply the jit decorator to this method to speed up the computations.

  2. In the constructor of the ExplicitComponent, create an attribute and assign to it a function that will compute the partial derivatives of the ExplicitComponent. This simply makes use of the JAX function, jacfwd, applied to the _compute_primal method.

  3. Create a method that computes the partials. In this example, it is called _compute_partials_jacfwd but could be any name that makes sense to the user. Apply the jit decorator to this method to speed up the computations.

  4. Make use of the _compute_primal method in the usual OpenMDAO ExplicitComponent.compute method.

  5. Make use of the _compute_partials_jacfwd in the usual OpenMDAO ExplicitComponent.compute_partials method.
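The jit-decorated methods in steps 1 and 3 take self as a static argument. The sketch below uses a hypothetical Scaler class, not an OpenMDAO component, to show what that implies, assuming the class does not override __hash__ or __eq__: attributes of self are read once at trace time, and because self hashes by identity, later changes to those attributes are not seen by the cached compiled function. This is one reason the comments in SellarDis1 say that options need special care when using jit.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Scaler:
    """Hypothetical stand-in for a component whose jitted method reads an attribute."""

    def __init__(self, scale):
        self.scale = scale

    # partial(jax.jit, static_argnums=(0,)) builds a decorator equivalent to
    # jax.jit(method, static_argnums=(0,)): argument 0 (self) is treated as a
    # compile-time constant rather than traced like an array.
    @partial(jax.jit, static_argnums=(0,))
    def _compute_primal(self, x):
        # self.scale is read once, while tracing, and baked into the compiled code
        return self.scale * x


s = Scaler(2.0)
first = float(s._compute_primal(jnp.array(3.0)))   # 6.0

# self hashes by identity, so changing an attribute does not trigger a retrace:
# the cached version that captured scale = 2.0 is reused.
s.scale = 10.0
second = float(s._compute_primal(jnp.array(3.0)))  # still 6.0

# A new instance is a different static argument, so it is traced afresh.
third = float(Scaler(10.0)._compute_primal(jnp.array(3.0)))  # 30.0
print(first, second, third)
```

If an option must change after the first call, creating a new instance (or not reading the option inside the jitted method) avoids stale results.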

class SellarDis1(om.ExplicitComponent):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # argnums specifies which positional arguments to differentiate with respect to.
        # Here we want derivatives with respect to all 3 inputs of _compute_primal.
        self.deriv_func_jacfwd = jax.jacfwd(self._compute_primal, argnums=[0, 1, 2])

    def initialize(self):
        # Added this option to this model to demonstrate how having options
        # requires special care when using jit. See comments below
        self.options.declare('scaling_ref', types=(float,), default=0.1)

    def setup(self):
        ref = self.options['scaling_ref']

        # Global Design Variable
        self.add_input('z', val=np.zeros(2))

        # Local Design Variable
        self.add_input('x', val=0.)

        # Coupling parameter
        self.add_input('y2', val=1.0)

        # Coupling output
        self.add_output('y1', val=1.0, lower=0.1, upper=1000., ref=ref)

    def setup_partials(self):
        # No method is given, so the partials are provided by compute_partials (via JAX)
        self.declare_partials('*', '*')

    # functools.partial binds static_argnums=(0,) to jax.jit, producing a decorator.
    # Argument 0 is self: marking it "static" tells jit not to trace it as an array,
    # which allows the jitted method to access attributes such as self.options.
    @partial(jax.jit, static_argnums=(0,))
    def _compute_primal(self, z, x, y2):
        return z[0]**2 + z[1] + x - 0.2*y2

    @partial(jax.jit, static_argnums=(0,))
    def _compute_partials_jacfwd(self, z, x, y2):
        # Returns a tuple with one Jacobian per entry in argnums
        dz, dx, dy2 = self.deriv_func_jacfwd(z, x, y2)
        return dz, dx, dy2

    def compute(self, inputs, outputs):
        # inputs.values() yields z, x, y2 in the order they were added in setup()
        outputs['y1'] = self._compute_primal(*inputs.values())

    def compute_partials(self, inputs, partials):
        dz, dx, dy2 = self._compute_partials_jacfwd(*inputs.values())

        partials['y1', 'z'] = dz
        partials['y1', 'x'] = dx
        partials['y1', 'y2'] = dy2
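Because _compute_primal is a pure function of its inputs, the JAX part of SellarDis1 can be exercised outside OpenMDAO. This sketch repeats the same expression as a free function, builds its jitted forward-mode Jacobian (steps 2 and 3), and compares the result with the hand-derived partials 2*z[0], 1, 1, and -0.2. The scalar inputs are given shape (1,) to mimic how OpenMDAO stores them, which is why every Jacobian block has a leading output dimension of 1; the input values are illustrative.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def compute_primal(z, x, y2):
    # Same expression as SellarDis1._compute_primal, written as a free function
    return z[0]**2 + z[1] + x - 0.2*y2


# Forward-mode Jacobian with respect to all three arguments, then jit-compiled
compute_partials = jax.jit(jax.jacfwd(compute_primal, argnums=(0, 1, 2)))

z = jnp.array([5.0, 2.0])
x = jnp.array([1.0])
y2 = jnp.array([1.0])

dz, dx, dy2 = compute_partials(z, x, y2)
# Hand-derived: dy1/dz = [2*z[0], 1] = [10, 1], dy1/dx = 1, dy1/dy2 = -0.2
print(dz.shape, dx.shape, dy2.shape)  # (1, 2) (1, 1) (1, 1)
```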

The second Sellar ExplicitComponent is written in the same way.

class SellarDis2(om.ExplicitComponent):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.deriv_func_jacfwd = jax.jacfwd(self._compute_primal, argnums=[0, 1])

    def setup(self):
        # Global Design Variable
        self.add_input('z', val=np.zeros(2))

        # Coupling parameter
        self.add_input('y1', val=1.0)

        # Coupling output
        self.add_output('y2', val=1.0, lower=0.1, upper=1000., ref=1.0)

    def setup_partials(self):
        self.declare_partials('*', '*')

    @partial(jax.jit, static_argnums=(0,))
    def _compute_primal(self, z, y1):
        # y1 is a length-1 array when called from compute and a scalar when called from
        # compute_partials. Array shapes are known at trace time, so this check works under jit.
        if np.ndim(y1) == 1:
            y1 = y1[0]

        # if y1.real < 0.0:
        #     y1 *= -1
        # Because of jit, Python conditionals on traced values cannot be used as is, as in the
        # two lines of code above. Fortunately, JAX provides control flow primitives for this.
        # For if statements, JAX provides the cond function.
        # See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
        # for more information about control flow when using jit
        y1 = jax.lax.cond(y1.real < 0.0, lambda y1: -y1, lambda y1: y1, y1)

        return y1**.5 + z[0] + z[1]

    @partial(jax.jit, static_argnums=(0,))
    def _compute_partials_jacfwd(self, z, y1):
        dz, dy1 = self.deriv_func_jacfwd(z, y1)
        return dz, dy1

    def compute(self, inputs, outputs):
        outputs['y2'] = self._compute_primal(*inputs.values())

    def compute_partials(self, inputs, partials):
        # Pass y1 as a scalar rather than a length-1 array; _compute_primal accepts either
        z, y1 = inputs.values()
        y1 = y1[0]
        dz, dy1 = self._compute_partials_jacfwd(z, y1)

        partials['y2', 'z'] = dz
        partials['y2', 'y1'] = dy1
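The sketch below isolates the conditional from SellarDis2 and checks that gradients flow through jax.lax.cond. It also shows jnp.where, a common alternative for simple elementwise selections; that substitution is our own illustration, not something the original component uses. For y1 = 4 the expected derivative of sqrt(|y1|) is 1/(2*sqrt(4)) = 0.25, and for y1 = -4 it is -0.25.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def sqrt_abs_cond(y1):
    # Mirrors SellarDis2: flip the sign of a negative y1 with lax.cond, then take the square root
    y1 = jax.lax.cond(y1 < 0.0, lambda v: -v, lambda v: v, y1)
    return y1 ** 0.5


@jax.jit
def sqrt_abs_where(y1):
    # jnp.where computes both branches and selects one; it also works under jit
    return jnp.where(y1 < 0.0, -y1, y1) ** 0.5


d_cond = jax.grad(sqrt_abs_cond)
d_where = jax.grad(sqrt_abs_where)

# Expected derivatives: about 0.25 at y1 = 4.0 and about -0.25 at y1 = -4.0
for y in (4.0, -4.0):
    print(y, float(d_cond(y)), float(d_where(y)))
```

lax.cond runs only the selected branch, which matters when a branch is expensive; jnp.where is often simpler when both branches are cheap and safe to evaluate.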

The rest of the code is standard OpenMDAO code and runs as usual.

class SellarDerivatives(om.Group):
    """
    Group containing the Sellar MDA. This version uses the disciplines with derivatives.
    """

    def setup(self):
        self.add_subsystem('d1', SellarDis1(), promotes=['x', 'z', 'y1', 'y2'])
        self.add_subsystem('d2', SellarDis2(), promotes=['z', 'y1', 'y2'])

        obj = self.add_subsystem('obj_cmp', om.ExecComp('obj = x**2 + z[1] + y1 + exp(-y2)', obj=0.0,
                                                        x=0.0, z=np.array([0.0, 0.0]), y1=0.0, y2=0.0),
                                 promotes=['obj', 'x', 'z', 'y1', 'y2'])

        con1 = self.add_subsystem('con_cmp1', om.ExecComp('con1 = 3.16 - y1', con1=0.0, y1=0.0),
                                  promotes=['con1', 'y1'])
        con2 = self.add_subsystem('con_cmp2', om.ExecComp('con2 = y2 - 24.0', con2=0.0, y2=0.0),
                                  promotes=['con2', 'y2'])

        # manually declare partials to allow graceful fallback to FD when nested under a higher
        # level complex step approximation.
        obj.declare_partials(of='*', wrt='*', method='cs')
        con1.declare_partials(of='*', wrt='*', method='cs')
        con2.declare_partials(of='*', wrt='*', method='cs')

        self.set_input_defaults('x', 1.0)
        self.set_input_defaults('z', np.array([5.0, 2.0]))


prob = om.Problem()
prob.model = model = SellarDerivatives()

model.add_design_var('z', lower=np.array([-10.0, 0.0]), upper=np.array([10.0, 10.0]))
model.add_design_var('x', lower=0.0, upper=10.0)
model.add_objective('obj')
model.add_constraint('con1', upper=0.0)
model.add_constraint('con2', upper=0.0)
model.add_constraint('x', upper=11.0, linear=True)

prob.set_solver_print(level=0)

prob.driver = om.ScipyOptimizeDriver(optimizer='SLSQP', tol=1e-9, disp=False)

prob.setup(check=False, mode='fwd')

prob.run_driver()
print(prob.get_val('obj'))
[3.18339401]