Computing Partial Derivatives of Explicit Components Using JAX
One of the barriers to using OpenMDAO is that, to take full advantage of it, the user needs to write code for the analytic partial derivatives of their Components. To avoid that, users can use the optional third-party JAX library, which can automatically differentiate native Python and NumPy functions.
This notebook gives an example of using JAX to do automatic differentiation (AD) for the Sellar example. Only forward-mode AD is used in this example, although JAX also supports reverse-mode AD.
Forward mode is better for “tall” Jacobian matrices (more outputs than inputs), whereas reverse mode is better for “wide” Jacobian matrices (more inputs than outputs).
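As a small illustration of the two modes, the sketch below applies both jax.jacfwd and jax.jacrev to a toy function. The function f and the input values are made up for this example and are not part of the Sellar problem. Both transforms should produce the same Jacobian; they differ only in how the work is organized.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical function mapping 2 inputs to 3 outputs,
    # so its Jacobian is 3 x 2 (a "tall" matrix).
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])

# Forward mode builds the Jacobian one input (column) at a time.
jac_fwd = jax.jacfwd(f)(x)

# Reverse mode builds the Jacobian one output (row) at a time.
jac_rev = jax.jacrev(f)(x)

print(jac_fwd.shape)
print(jnp.allclose(jac_fwd, jac_rev))
```

For a function like this one, with fewer inputs than outputs, forward mode needs fewer passes than reverse mode.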
This notebook also shows how to use JAX’s just-in-time (jit) compiling capabilities to dramatically speed up computations.
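A minimal sketch of jit compilation, assuming a made-up function sum_of_squares that is not part of the Sellar model: jax.jit wraps a function so that it is traced and compiled on its first call, and later calls with inputs of the same shape and dtype reuse the compiled version.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Hypothetical function used only to illustrate jit.
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

# jax.jit returns a compiled version of the same function.
sum_of_squares_jit = jax.jit(sum_of_squares)

x = jnp.arange(1000.0)
result_plain = sum_of_squares(x)
result_jit = sum_of_squares_jit(x)  # first call triggers compilation

print(jnp.allclose(result_plain, result_jit))
```

The actual speedup depends on the function and the hardware, so it is worth timing both versions for your own components.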
The use of JAX is optional for OpenMDAO, so if it is not already installed, install it by issuing one of the following commands at your operating system command prompt:
pip install jax jaxlib
pip install openmdao[jax]
pip install openmdao[all]
!pip install jax
Here are some standard OpenMDAO imports:
from functools import partial
import numpy as np
import openmdao.api as om
The JAX library includes a NumPy-like API, jax.numpy, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with jax.numpy. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs.
To use jax.numpy, import it using the commonly used jnp abbreviation.
import jax
import jax.numpy as jnp
The default for JAX is to do single precision computations. For this example, we want to use double precision, so this line of code is needed.
jax.config.update("jax_enable_x64", True)
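A quick way to confirm that the setting took effect is to look at the dtype of a newly created array. This is only a sanity-check sketch; the array values are arbitrary.

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# With 64-bit mode enabled, floating-point arrays default to float64
# instead of float32.
x = jnp.array([1.0, 2.0, 3.0])
print(x.dtype)
```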
Here is one of the ExplicitComponents in the model where derivatives will be computed using JAX. Comments interspersed in the code provide some explanations and guidance. Here is an overview of the steps that need to be taken to make use of JAX for your ExplicitComponent.
NOTE: A newer, experimental API for using JAX with ExplicitComponents has been developed that simplifies the process outlined below. It is described elsewhere in the OpenMDAO documentation.
1. Write a method to compute the outputs from the inputs. Borrowing from AD terminology, a suggested name for this method is _compute_primal. This method is the same as what you would normally write for the compute method of an ExplicitComponent, but it takes as its arguments the actual individual input variables rather than a dictionary of the inputs. This allows us to use JAX’s AD capabilities on this method. The _compute_primal method simply returns the outputs as a single value, or as a tuple if there is more than one output. Apply the jit decorator to this method to speed up the computations.
2. In the constructor of the ExplicitComponent, create an attribute and assign to it a function that will compute the partial derivatives of the ExplicitComponent. This simply makes use of the JAX function jacfwd, applied to the _compute_primal method.
3. Create a method that computes the partials. In this example, it is called _compute_partials_jacfwd, but it could be any name that makes sense to the user. Apply the jit decorator to this method to speed up the computations.
4. Make use of the _compute_primal method in the usual OpenMDAO ExplicitComponent.compute method.
5. Make use of the _compute_partials_jacfwd method in the usual OpenMDAO ExplicitComponent.compute_partials method.
class SellarDis1(om.ExplicitComponent):

    def __init__(self):
        super().__init__()

        # argnums specifies which positional arguments to differentiate with respect to.
        # Here we want derivatives with respect to all 3 inputs of _compute_primal.
        self.deriv_func_jacfwd = jax.jacfwd(self._compute_primal, argnums=[0, 1, 2])

    def initialize(self):
        # Added this option to this model to demonstrate how having options
        # requires special care when using jit. See comments below.
        self.options.declare('scaling_ref', types=(float,), default=0.1)

    def setup(self):
        ref = self.options['scaling_ref']

        # Global Design Variable
        self.add_input('z', val=np.zeros(2))

        # Local Design Variable
        self.add_input('x', val=0.)

        # Coupling parameter
        self.add_input('y2', val=1.0)

        # Coupling output
        self.add_output('y1', val=1.0, lower=0.1, upper=1000., ref=ref)

    def setup_partials(self):
        # Declare all partials. No approximation method is given, so the values
        # come from compute_partials, which uses JAX.
        self.declare_partials('*', '*')

    # functools.partial returns a new callable with the specified arguments bound.
    # Here it binds static_argnums so that jax.jit can be used as a decorator.
    # Need to tell jit that self is a "static" argument.
    # This allows the jitted method to access the options attribute.
    @partial(jax.jit, static_argnums=(0,))
    def _compute_primal(self, z, x, y2):
        return z[0]**2 + z[1] + x - 0.2*y2

    @partial(jax.jit, static_argnums=(0,))
    def _compute_partials_jacfwd(self, z, x, y2):
        # Always returns a tuple, with one entry per argument listed in argnums
        dz, dx, dy2 = self.deriv_func_jacfwd(z, x, y2)
        return dz, dx, dy2

    def compute(self, inputs, outputs):
        # Inputs are unpacked in the order they were added in setup: z, x, y2
        outputs['y1'] = self._compute_primal(*inputs.values())

    def compute_partials(self, inputs, partials):
        dz, dx, dy2 = self._compute_partials_jacfwd(*inputs.values())

        partials['y1', 'z'] = dz
        partials['y1', 'x'] = dx
        partials['y1', 'y2'] = dy2
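To see what jacfwd returns for this discipline without any OpenMDAO machinery, the following sketch applies it to a standalone copy of the same equation. The function name y1_primal and the input values are chosen only for illustration. The results can be compared against the hand-derived partials: dy1/dz = [2*z[0], 1], dy1/dx = 1, and dy1/dy2 = -0.2.

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

def y1_primal(z, x, y2):
    # Standalone copy of the first Sellar discipline equation.
    return z[0]**2 + z[1] + x - 0.2*y2

# Differentiate with respect to all three positional arguments.
# The result is a tuple with one entry per argument.
deriv_func = jax.jacfwd(y1_primal, argnums=(0, 1, 2))

z = jnp.array([5.0, 2.0])
x = jnp.array(1.0)
y2 = jnp.array(1.0)

dz, dx, dy2 = deriv_func(z, x, y2)
print(dz, dx, dy2)
```

A check like this is a cheap way to confirm that the function passed to jacfwd has the expected argument order before wiring it into compute_partials.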
Similarly, the second Sellar ExplicitComponent should be written in the same way.
class SellarDis2(om.ExplicitComponent):

    def __init__(self):
        super().__init__()

        # Differentiate with respect to both inputs of _compute_primal: z and y1
        self.deriv_func_jacfwd = jax.jacfwd(self._compute_primal, argnums=[0, 1])

    def setup(self):
        # Global Design Variable
        self.add_input('z', val=jnp.zeros(2))

        # Coupling parameter
        self.add_input('y1', val=1.0)

        # Coupling output
        self.add_output('y2', val=1.0, lower=0.1, upper=1000., ref=1.0)

    def setup_partials(self):
        self.declare_partials('*', '*')

    @partial(jax.jit, static_argnums=(0,))
    def _compute_primal(self, z, y1):
        # Depending on whether this is called via compute or compute_partials, y1 could have
        # different dimensions. It's just a scalar though.
        if np.ndim(y1) == 1:
            y1 = y1[0]

        # if y1.real < 0.0:
        #     y1 *= -1
        # Because of jit, conditionals cannot be used as is, as in the above two lines of code.
        # Fortunately, JAX provides control flow primitives to deal with that.
        # For if statements, JAX provides the cond function.
        # See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
        # for more information about control flow when using jit.
        y1 = jax.lax.cond(y1.real < 0.0, lambda y1: -y1, lambda y1: y1, y1)

        return y1**.5 + z[0] + z[1]

    @partial(jax.jit, static_argnums=(0,))
    def _compute_partials_jacfwd(self, z, y1):
        dz, dy1 = self.deriv_func_jacfwd(z, y1)
        return dz, dy1

    def compute(self, inputs, outputs):
        outputs['y2'] = self._compute_primal(*inputs.values())

    def compute_partials(self, inputs, partials):
        # pass in y1, which is used in a conditional, as a scalar, which is hashable
        z, y1 = inputs.values()
        y1 = y1[0]
        dz, dy1 = self._compute_partials_jacfwd(z, y1)

        partials['y2', 'z'] = dz
        partials['y2', 'y1'] = dy1
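The role of jax.lax.cond can be seen in isolation with a standalone copy of the second discipline equation. This is a sketch only; the function name y2_primal and the test values are assumptions made for the example. Under jit, the value of y1 is not known when the function is traced, so the branch has to be expressed with cond rather than a Python if statement.

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

@jax.jit
def y2_primal(z, y1):
    # Flip the sign of y1 when it is negative, using a traceable conditional.
    y1 = jax.lax.cond(y1 < 0.0, lambda v: -v, lambda v: v, y1)
    return y1**0.5 + z[0] + z[1]

z = jnp.array([5.0, 2.0])

# Both calls evaluate sqrt(4) + 5 + 2
y2_pos = y2_primal(z, jnp.array(4.0))
y2_neg = y2_primal(z, jnp.array(-4.0))

# Forward-mode derivatives pass through cond as well.
dz, dy1 = jax.jacfwd(y2_primal, argnums=(0, 1))(z, jnp.array(4.0))
print(y2_pos, y2_neg, dz, dy1)
```

For y1 = 4, the hand-derived value of dy2/dy1 is 0.5 / sqrt(4) = 0.25, which gives a simple reference point for the JAX result.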
The rest of this code is standard OpenMDAO code. The code can be run as normal.
class SellarDerivatives(om.Group):
    """
    Group containing the Sellar MDA. This version uses the disciplines with derivatives.
    """

    def setup(self):
        self.add_subsystem('d1', SellarDis1(), promotes=['x', 'z', 'y1', 'y2'])
        self.add_subsystem('d2', SellarDis2(), promotes=['z', 'y1', 'y2'])

        obj = self.add_subsystem('obj_cmp', om.ExecComp('obj = x**2 + z[1] + y1 + exp(-y2)', obj=0.0,
                                                        x=0.0, z=np.array([0.0, 0.0]), y1=0.0, y2=0.0),
                                 promotes=['obj', 'x', 'z', 'y1', 'y2'])

        con1 = self.add_subsystem('con_cmp1', om.ExecComp('con1 = 3.16 - y1', con1=0.0, y1=0.0),
                                  promotes=['con1', 'y1'])
        con2 = self.add_subsystem('con_cmp2', om.ExecComp('con2 = y2 - 24.0', con2=0.0, y2=0.0),
                                  promotes=['con2', 'y2'])

        # manually declare partials to allow graceful fallback to FD when nested under a higher
        # level complex step approximation.
        obj.declare_partials(of='*', wrt='*', method='cs')
        con1.declare_partials(of='*', wrt='*', method='cs')
        con2.declare_partials(of='*', wrt='*', method='cs')

        self.set_input_defaults('x', 1.0)
        self.set_input_defaults('z', np.array([5.0, 2.0]))


prob = om.Problem()
prob.model = model = SellarDerivatives()

model.add_design_var('z', lower=np.array([-10.0, 0.0]), upper=np.array([10.0, 10.0]))
model.add_design_var('x', lower=0.0, upper=10.0)
model.add_objective('obj')
model.add_constraint('con1', upper=0.0)
model.add_constraint('con2', upper=0.0)
model.add_constraint('x', upper=11.0, linear=True)

prob.set_solver_print(level=0)

prob.driver = om.ScipyOptimizeDriver(optimizer='SLSQP', tol=1e-9, disp=False)

prob.setup(check=False, mode='fwd')

prob.run_driver()
print(prob.get_val('obj'))
[3.18339401]