Computing Partial Derivatives using JaxExplicitComponent


One of the barriers to using OpenMDAO is that, to take full advantage of it, users need to write code for the analytic partial derivatives of their components. To avoid that, users can use the optional third-party JAX library, which can automatically differentiate native Python and NumPy functions.

This notebook gives an example of using JAX to do automatic differentiation (AD) for the Sellar example. Based on the relative sizes of a component's inputs and outputs, JaxExplicitComponent automatically chooses forward mode (when there are more outputs than inputs) or reverse mode (when there are more inputs than outputs) to compute derivatives. If a JaxExplicitComponent's 'matrix_free' attribute is set, the component will use jax.jvp in forward mode and jax.vjp in reverse mode.
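The sketch below uses plain JAX, not the JaxExplicitComponent internals, to show what the two modes compute. It uses a function shaped like SellarDis1's compute_primal (defined later in this notebook). The inputs are packed into one vector v = [z0, z1, x, y2] purely for illustration, and the names y1_of, t and w are hypothetical.

- jax.jacfwd builds the Jacobian one input direction at a time.
- jax.jacrev builds it one output row at a time.
- jax.jvp and jax.vjp are the matrix-free building blocks. Each returns a single Jacobian-vector or vector-Jacobian product without forming the Jacobian.

```python
import jax
import jax.numpy as jnp


def y1_of(v):
    # Hypothetical packing of SellarDis1's inputs: v = [z0, z1, x, y2]; one output.
    z0, z1, x, y2 = v[0], v[1], v[2], v[3]
    return jnp.atleast_1d(z0**2 + z1 + x - 0.2 * y2)


v = jnp.array([5.0, 2.0, 1.0, 1.0])

# Full (1 x 4) Jacobian via forward mode and via reverse mode.
J_fwd = jax.jacfwd(y1_of)(v)
J_rev = jax.jacrev(y1_of)(v)

# Matrix-free products.
t = jnp.array([1.0, 0.0, 0.0, 0.0])      # forward seed: direction d/dz0
_, jvp_out = jax.jvp(y1_of, (v,), (t,))  # J @ t, shape (1,)

w = jnp.array([1.0])                     # reverse seed on the single output
_, vjp_fun = jax.vjp(y1_of, v)
(vjp_out,) = vjp_fun(w)                  # w @ J, shape (4,)
```

With one output and four inputs, reverse mode needs a single pass, whereas forward mode needs one pass per input. This is the size-based trade-off described above.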

By default, JaxExplicitComponent also uses JAX's just-in-time (jit) compilation to speed up computations. To disable jit, set self.options['use_jit'] to False.
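The following plain-JAX sketch shows what jit buys. It is an analogy only and does not show how OpenMDAO wires use_jit internally. The hypothetical trace_count side effect runs only while JAX traces the Python function. As a result, a second call with the same argument shapes and dtypes reuses the compiled code instead of re-running the Python body.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when the Python body actually executes


def primal(z, x):
    global trace_count
    trace_count += 1  # side effect: runs at trace time, not on cached jitted calls
    return z[0] ** 2 + z[1] + x


fast_primal = jax.jit(primal)

a = fast_primal(jnp.array([5.0, 2.0]), jnp.array(1.0))  # traces and compiles
b = fast_primal(jnp.array([1.0, 0.0]), jnp.array(3.0))  # same shapes/dtypes: cached
count_after_jit = trace_count

# Calling the plain function runs the Python body every time
# (loosely analogous to use_jit=False).
slow = primal(jnp.array([5.0, 2.0]), jnp.array(1.0))
```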

JAX is optional for OpenMDAO, so if it is not already installed, you need to install it. See the installation instructions for more information about installing JAX.

!pip install jax

Here are some standard OpenMDAO imports:

import numpy as np
import openmdao.api as om

The JAX library includes a NumPy-like API, jax.numpy, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with jax.numpy. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs.

To use jax.numpy, import it; by convention it is abbreviated as jnp.

import jax
import jax.numpy as jnp
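One practical difference from NumPy matters later in this notebook: JAX arrays are immutable. The sketch below is plain NumPy/JAX, and its variable names are illustrative. It shows that the same NumPy-style expressions work, but item assignment raises a TypeError. Functional updates through the .at[...] property return a new array instead.

```python
import numpy as np
import jax.numpy as jnp

z_np = np.array([5.0, 2.0])
z_jnp = jnp.array([5.0, 2.0])

# The same NumPy-style expression works on both array types.
same = np.allclose(np.sum(z_np**2), jnp.sum(z_jnp**2))

# In-place item assignment is not allowed on JAX arrays.
try:
    z_jnp[0] = -1.0
    mutated = True
except TypeError:
    mutated = False

# Functional update: returns a new array and leaves z_jnp unchanged.
z_new = z_jnp.at[0].set(-1.0)
```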

Below is one of the JaxExplicitComponents in the model whose derivatives will be computed using JAX. Comments in the code provide explanations and guidance. Here is an overview of the steps needed to use JAX in your JaxExplicitComponent.

  1. Inherit your component from JaxExplicitComponent.

  2. Write a method named compute_primal that computes the outputs from the inputs. This method plays the same role as the compute method of an ExplicitComponent, but it takes the individual input variables as arguments rather than a dictionary of inputs, and it returns the outputs as a tuple. This allows JAX's AD capabilities to be applied to the method. Ordering of the inputs and outputs is critical: the inputs passed into the method and the outputs returned from it must be in the same order in which they are declared as inputs and outputs in the component. Discrete inputs, if any, are passed individually as arguments after the continuous variables.

  3. By default, your component will jit the compute_primal method. If you don't want this, set self.options['use_jit'] to False.

  4. For a typical component like SellarDis2 below, that’s it. You can skip step 5.

  5. However, if your compute_primal depends on variables that are 'static', i.e., that don't change during the computation of derivatives, you'll need to define a get_self_statics method on your component. This method returns a tuple containing all of the static variables that compute_primal depends on, excluding discrete input variables. The returned tuple must be hashable. If any of these static values changes, JAX will recompile the compute_primal function (assuming 'use_jit' is True). In SellarDis1 below, there is one static attribute, self.staticvar, and one static option, self.options['static_opt'].
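The role of get_self_statics can be illustrated with plain JAX's static_argnums. That option treats an argument as a compile-time constant, which must be hashable, and retraces whenever its value changes. The sketch below is an analogy under that assumption, not OpenMDAO's actual implementation. The function name primal_with_statics and the layout of the statics tuple are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # counts how many times JAX (re)traces the function


@partial(jax.jit, static_argnums=(3,))
def primal_with_statics(z, x, y2, statics):
    # `statics` plays the role of get_self_statics(): a hashable tuple.
    global trace_count
    trace_count += 1  # runs only when JAX traces, i.e. on (re)compilation
    staticvar, static_opt = statics
    return (z[0]**2 + z[1] + x - 0.2 * y2 * staticvar * static_opt,)


z = jnp.array([5.0, 2.0])
x = jnp.array(1.0)
y2 = jnp.array(1.0)

(out1,) = primal_with_statics(z, x, y2, (1.0, 1.0))  # first call: trace + compile
(out2,) = primal_with_statics(z, x, y2, (1.0, 1.0))  # same statics: cached
traces_before_change = trace_count
(out3,) = primal_with_statics(z, x, y2, (1.0, 2.0))  # new static value: retrace
```

A value read from a closure, rather than passed in like this, would be baked into the compiled code at trace time. Later changes would then be silently ignored, which is presumably what get_self_statics guards against.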

class SellarDis1(om.JaxExplicitComponent):
    def initialize(self):
        # Added this option to this model to demonstrate how having options
        # requires special care when using jit. See comments below
        self.options.declare('static_opt', types=(float,), default=1.)
        self.staticvar = 1.

    def setup(self):
        # Global Design Variable
        self.add_input('z', val=jnp.zeros(2))

        # Local Design Variable
        self.add_input('x', val=0.)

        # Coupling parameter
        self.add_input('y2', val=1.0)

        # Coupling output
        self.add_output('y1', val=1.0, lower=0.1, upper=1000., ref=0.1)

    def setup_partials(self):
        self.declare_partials('*', '*')

    # Because the output of compute_primal depends on static variables, in this case self.staticvar
    # and self.options['static_opt'], we must define a get_self_statics method. This method must
    # return a tuple of all static variables; their order in the tuple doesn't matter. If your
    # component happens to have discrete inputs, do NOT return them here. Discrete inputs are passed
    # into the compute_primal function individually, after the continuous variables.
    def get_self_statics(self):
        # return value must be hashable
        return self.staticvar, self.options['static_opt']

    def compute_primal(self, z, x, y2):
        return (z[0]**2 + z[1] + x - 0.2*y2*self.staticvar*self.options['static_opt'],)
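Because compute_primal takes individual arguments in declaration order, each partial derivative corresponds to one argument position. The sketch below copies SellarDis1's formula into a hypothetical standalone function, with the statics as plain keyword arguments. It then differentiates the function with jax.jacrev over argnums (0, 1, 2). This is for illustration only; the component itself handles this internally. By hand, ∂y1/∂z = [2·z0, 1], ∂y1/∂x = 1 and ∂y1/∂y2 = -0.2·staticvar·static_opt, and JAX reproduces these values.

```python
import jax
import jax.numpy as jnp


def compute_primal(z, x, y2, staticvar=1.0, static_opt=1.0):
    # Standalone copy of SellarDis1.compute_primal; statics passed as plain arguments.
    return (z[0]**2 + z[1] + x - 0.2 * y2 * staticvar * static_opt,)


def y1_only(z, x, y2):
    # Pull the single output out of the returned tuple.
    return compute_primal(z, x, y2)[0]


z = jnp.array([5.0, 2.0])
x = jnp.array(1.0)
y2 = jnp.array(1.0)

# One Jacobian per argument, in the same order as the arguments.
dz, dx, dy2 = jax.jacrev(y1_only, argnums=(0, 1, 2))(z, x, y2)
```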

The second Sellar component, SellarDis2, is written in a similar way.

class SellarDis2(om.JaxExplicitComponent):
    def setup(self):
        # Global Design Variable
        self.add_input('z', val=jnp.zeros(2))

        # Coupling parameter
        self.add_input('y1', val=1.0)

        # Coupling output
        self.add_output('y2', val=1.0, lower=0.1, upper=1000., ref=1.0)

    def setup_partials(self):
        self.declare_partials('*', '*')

    def compute_primal(self, z, y1):
        # if y1[0].real < 0.0:
        #     y1[0] *= -1
        # Because of jit, Python conditionals on traced values cannot be used as in the two
        # commented-out lines above (and JAX arrays cannot be modified in place anyway).
        # Fortunately, JAX provides control flow primitives to deal with that.
        # For if statements, JAX provides the cond function.
        # See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
        # for more information about control flow when using jit.
        y1 = jax.lax.cond(y1[0].real < 0.0, lambda x: -x, lambda x: x, y1[0])

        # Return the outputs as a tuple, in declaration order (here only y2).
        return (y1**.5 + z[0] + z[1],)
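The following standalone sketch in plain JAX contrasts the two approaches; its function names are hypothetical.

- A Python if on a traced value fails under jax.jit. Recent JAX versions raise TracerBoolConversionError, a subclass of jax.errors.ConcretizationTypeError.
- jax.lax.cond(pred, true_fun, false_fun, operand) works under jit, and JAX can differentiate through it.

```python
import jax
import jax.numpy as jnp


def abs_sqrt_if(y):
    # Python control flow on a traced value: fails under jit.
    if y < 0.0:
        y = -y
    return y ** 0.5


def abs_sqrt_cond(y):
    # lax.cond picks a branch at run time, so it can be traced and jitted.
    y = jax.lax.cond(y < 0.0, lambda v: -v, lambda v: v, y)
    return y ** 0.5


try:
    jax.jit(abs_sqrt_if)(jnp.array(-4.0))
    python_if_ok = True
except jax.errors.ConcretizationTypeError:
    python_if_ok = False

f = jax.jit(abs_sqrt_cond)
vals = (f(jnp.array(-4.0)), f(jnp.array(4.0)))

# Derivatives flow through whichever branch is taken: d/dy sqrt(|y|) = ±0.25 at y = ∓4.
grads = (jax.grad(abs_sqrt_cond)(-4.0), jax.grad(abs_sqrt_cond)(4.0))
```

For a simple sign flip like this one, jnp.abs would also work. lax.cond is the general tool when the branches differ in other ways.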

The rest of this code is standard OpenMDAO code and can be run as usual.

class SellarDerivatives(om.Group):
    """
    Group containing the Sellar MDA. This version uses the disciplines with derivatives.
    """

    def setup(self):
        self.add_subsystem('d1', SellarDis1(), promotes=['x', 'z', 'y1', 'y2'])
        self.add_subsystem('d2', SellarDis2(), promotes=['z', 'y1', 'y2'])

        obj = self.add_subsystem('obj_cmp', om.ExecComp('obj = x**2 + z[1] + y1 + exp(-y2)', obj=0.0,
                                                        x=0.0, z=np.array([0.0, 0.0]), y1=0.0, y2=0.0),
                                 promotes=['obj', 'x', 'z', 'y1', 'y2'])

        con1 = self.add_subsystem('con_cmp1', om.ExecComp('con1 = 3.16 - y1', con1=0.0, y1=0.0),
                                  promotes=['con1', 'y1'])
        con2 = self.add_subsystem('con_cmp2', om.ExecComp('con2 = y2 - 24.0', con2=0.0, y2=0.0),
                                  promotes=['con2', 'y2'])

        # manually declare partials to allow graceful fallback to FD when nested under a higher
        # level complex step approximation.
        obj.declare_partials(of='*', wrt='*', method='cs')
        con1.declare_partials(of='*', wrt='*', method='cs')
        con2.declare_partials(of='*', wrt='*', method='cs')

        self.set_input_defaults('x', 1.0)
        self.set_input_defaults('z', np.array([5.0, 2.0]))
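The ExecComps above declare their partials with method='cs' (complex step). The plain-NumPy sketch below shows what that approximation computes, using the generic formula f'(x) ≈ Im[f(x + ih)]/h on the exp(-y2) term of the objective. It is a textbook sketch, not OpenMDAO's internal code. The helper names and the step size h = 1e-30 are choices made for illustration. Because the formula involves no subtraction, h can be made tiny without round-off error. Complex numbers are also not ordered, which is presumably why SellarDis2 compares only y1[0].real.

```python
import numpy as np


def cs_derivative(f, x, h=1e-30):
    # Complex-step approximation: no subtractive cancellation, so h can be tiny.
    return np.imag(f(x + 1j * h)) / h


def obj_term(y2):
    # The exp(-y2) term of obj = x**2 + z[1] + y1 + exp(-y2).
    return np.exp(-y2)


y2 = 3.0
approx = cs_derivative(obj_term, y2)
exact = -np.exp(-y2)  # analytic derivative of exp(-y2)
```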


prob = om.Problem()
prob.model = model = SellarDerivatives()

model.add_design_var('z', lower=np.array([-10.0, 0.0]), upper=np.array([10.0, 10.0]))
model.add_design_var('x', lower=0.0, upper=10.0)
model.add_objective('obj')
model.add_constraint('con1', upper=0.0)
model.add_constraint('con2', upper=0.0)
model.add_constraint('x', upper=11.0, linear=True)

prob.set_solver_print(level=0)

prob.driver = om.ScipyOptimizeDriver(optimizer='SLSQP', tol=1e-9, disp=False)

prob.setup(check=False, mode='fwd')

prob.run_driver()
print(prob.get_val('obj'))
print(prob.get_val('z'))
print(prob.get_val('x'))
[3.18339401]
[1.97763890e+00 8.40262301e-15]
[3.61684e-16]
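As a rough sanity check on the printed results, the coupled Sellar equations can be re-solved at the reported design point. This check is our own addition, not part of the original example. It uses fixed-point iteration, with staticvar and static_opt at their default of 1.0 and the ~1e-15 values of z[1] and x taken as 0. At the optimum, con1 = 3.16 - y1 should be approximately zero (active), and obj should reproduce the printed value.

```python
import numpy as np

# Design point printed above; z[1] and x are ~1e-15, so they are taken as 0 here.
z = np.array([1.97763890, 0.0])
x = 0.0

# Gauss-Seidel-style fixed-point iteration on the two coupled disciplines.
y1, y2 = 1.0, 1.0
for _ in range(100):
    y1 = z[0]**2 + z[1] + x - 0.2 * y2   # SellarDis1 with both statics = 1.0
    y2 = np.sqrt(abs(y1)) + z[0] + z[1]  # SellarDis2 (abs mirrors its sign flip)

obj = x**2 + z[1] + y1 + np.exp(-y2)
con1 = 3.16 - y1
con2 = y2 - 24.0
```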