Computing Partial Derivatives using JaxExplicitComponent
One of the barriers to using OpenMDAO is that, to truly take advantage of it, the user needs to write code for the analytic partial derivatives of their Components. To avoid that, users can use the optional third-party JAX library, which can automatically differentiate native Python and NumPy functions.
This notebook gives an example of using JAX to do automatic differentiation (AD) for the Sellar example. Based on the sizes of a component's inputs vs. its outputs, the JaxExplicitComponent will use either forward mode (when there are more outputs than inputs) or reverse mode (when there are more inputs than outputs) to compute derivatives. This choice is made automatically. If a JaxExplicitComponent's 'matrix_free' attribute is set, the component will use the jax.jvp method in forward mode and jax.vjp in reverse mode.
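To make the forward/reverse distinction concrete, here is a small standalone JAX sketch. The function f is a made-up example and is not part of OpenMDAO. A single jax.jvp call produces one column of the Jacobian (J v), and a single jax.vjp call produces one row (u^T J). Forward mode therefore needs one pass per input and reverse mode one pass per output, so the cheaper choice depends on which side is smaller.

```python
import jax
import jax.numpy as jnp

# Illustrative function with 3 inputs and 2 outputs (not part of the Sellar model).
def f(x):
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([2.0, 3.0, 0.5])

# Full 2 x 3 Jacobian, for reference.
J = jax.jacfwd(f)(x)

# Forward mode: one JVP with a unit tangent yields one column of J.
v = jnp.array([1.0, 0.0, 0.0])
y, jv = jax.jvp(f, (x,), (v,))   # jv equals J[:, 0], i.e. [3, 0]

# Reverse mode: one VJP with a unit cotangent yields one row of J.
y_rev, f_vjp = jax.vjp(f, x)
u = jnp.array([1.0, 0.0])
(uj,) = f_vjp(u)                 # uj equals J[0, :], i.e. [3, 2, 0]
```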
The JaxExplicitComponent will also use JAX's just-in-time (jit) compilation by default to speed up computations. jit can be disabled by setting self.options['use_jit'] to False.
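The effect of jit can be observed in plain JAX. jax.jit traces a Python function once for each distinct input shape and dtype, compiles the trace, and then reuses the compiled version on later calls. The standalone sketch below counts how often tracing happens. The function g and the counter are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def g(x):
    # This Python side effect runs only while JAX traces the function,
    # not every time the compiled version is called.
    global trace_count
    trace_count += 1
    return jnp.sum(x ** 2)

g_jit = jax.jit(g)

g_jit(jnp.ones(3))       # traced and compiled for shape (3,)
g_jit(2 * jnp.ones(3))   # same shape and dtype: compiled code is reused
g_jit(jnp.ones(4))       # new shape: traced and compiled again

print(trace_count)       # 2
```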
JAX is optional for OpenMDAO, so if it is not already installed, the user needs to install it. See the installation instructions for more information about installing JAX.
!pip install jax
Here are some standard OpenMDAO imports:
import numpy as np
import openmdao.api as om
The JAX library includes a NumPy-like API, jax.numpy, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with jax.numpy. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs.
To use jax.numpy, import it under the commonly used jnp abbreviation.
import jax
import jax.numpy as jnp
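Here is a quick standalone sketch of that similarity. It also shows one difference worth knowing before writing compute_primal methods: JAX arrays are immutable, so in-place updates use the .at[...].set(...) syntax, which returns a new array. The array names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

a_np = np.array([1.0, 2.0, 3.0])
a_jnp = jnp.array([1.0, 2.0, 3.0])

# Most NumPy functions have a jax.numpy counterpart with the same name.
# Note that jax.numpy defaults to 32-bit floats unless 64-bit mode is
# enabled, so results can differ from NumPy in the last digits.
s_np = np.sin(a_np).sum()
s_jnp = jnp.sin(a_jnp).sum()

# In-place assignment such as a_jnp[0] = 10.0 raises a TypeError.
# .at[0].set(...) returns a new array and leaves a_jnp unchanged.
b_jnp = a_jnp.at[0].set(10.0)
```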
Here is one of the JaxExplicitComponents in the model whose derivatives will be computed using JAX. Comments interspersed in the code provide some explanations and guidance. Here is an overview of the steps needed to make use of JAX in your JaxExplicitComponent:
1. Inherit your component from JaxExplicitComponent.
2. Write a method named compute_primal to compute the outputs from the inputs. This method is the same as what you would normally write for the compute method of an ExplicitComponent, except that it takes the individual input variables as its arguments rather than a dictionary of the inputs, and it returns the outputs as a tuple. This allows JAX's AD capabilities to be applied to the method. Ordering of the inputs and outputs is critical: the order of the inputs passed into the method and of the outputs returned from it must match the order in which they are declared as inputs and outputs in the component. Discrete inputs, if any, are passed individually as arguments after the continuous variables.
3. By default your component will jit the compute_primal method. If for some reason you don't want this, set self.options['use_jit'] to False.
4. For a typical component like SellarDis2 below, that's it. You can skip step 5.
5. However, if your compute_primal depends on variables that are 'static', i.e., they don't change during computation of derivatives, you'll need to define a get_self_statics method on your component. This method returns a tuple containing all of the static variables that your compute_primal method depends on, excluding discrete input variables. The returned tuple must be hashable. If these static values ever change, JAX will recompile the compute_primal function, assuming 'use_jit' is True. In SellarDis1 below, there is one static attribute, self.staticvar, and one static option variable, self.options['static_opt'].
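The hashability requirement in step 5 comes from how jax.jit handles non-array arguments. Here is a standalone illustration. The function primal and the statics tuple are hypothetical stand-ins and do not show OpenMDAO's internal wiring. jax.jit's static_argnums marks an argument as a compile-time constant: JAX hashes it, reuses the compiled code while the value is unchanged, and retraces when it changes.

```python
import jax
import jax.numpy as jnp

n_traces = 0

# Hypothetical stand-in for compute_primal: `statics` plays the role of the
# tuple that get_self_statics() would return.
def primal(statics, x):
    global n_traces
    n_traces += 1
    scale, offset = statics   # plain Python floats, fixed at trace time
    return scale * x + offset

# Argument 0 is static, so it must be hashable (a tuple works, a list does not).
primal_jit = jax.jit(primal, static_argnums=0)

x = jnp.array([1.0, 2.0])
primal_jit((2.0, 0.5), x)       # first trace
primal_jit((2.0, 0.5), x)       # equal statics: cached, no retrace
y = primal_jit((3.0, 0.5), x)   # statics changed: retraced and recompiled

print(n_traces)  # 2
```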
class SellarDis1(om.JaxExplicitComponent):
    def initialize(self):
        # Added this option to this model to demonstrate how having options
        # requires special care when using jit. See comments below
        self.options.declare('static_opt', types=(float,), default=1.)
        self.staticvar = 1.

    def setup(self):
        # Global Design Variable
        self.add_input('z', val=jnp.zeros(2))
        # Local Design Variable
        self.add_input('x', val=0.)
        # Coupling parameter
        self.add_input('y2', val=1.0)
        # Coupling output
        self.add_output('y1', val=1.0, lower=0.1, upper=1000., ref=0.1)

    def setup_partials(self):
        self.declare_partials('*', '*')

    # Because our compute_primal output depends on static variables, in this case self.staticvar
    # and self.options['static_opt'], we must define a get_self_statics method. This method must
    # return a tuple of all static variables. Their order in the tuple doesn't matter. If your
    # component happens to have discrete inputs, do NOT return them here. Discrete inputs are passed
    # into the compute_primal function individually, after the continuous variables.
    def get_self_statics(self):
        # return value must be hashable
        return self.staticvar, self.options['static_opt']

    def compute_primal(self, z, x, y2):
        return (z[0]**2 + z[1] + x - 0.2*y2*self.staticvar*self.options['static_opt'],)
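To see which partials SellarDis1 needs, here is a standalone sketch that differentiates a copy of its compute_primal math with plain JAX. y1_func is a hypothetical name, and the static values are fixed at their defaults of 1.0. The component has four input entries (z has two) and one output, so by the rule described earlier reverse mode is the natural choice, and jax.jacrev is used here. The sketch illustrates the math only and does not describe the calls OpenMDAO makes internally.

```python
import jax
import jax.numpy as jnp

# Copy of SellarDis1.compute_primal's math, with the statics as plain floats.
def y1_func(z, x, y2, staticvar=1.0, static_opt=1.0):
    return z[0]**2 + z[1] + x - 0.2 * y2 * staticvar * static_opt

z = jnp.array([5.0, 2.0])
x = jnp.array(1.0)
y2 = jnp.array(1.0)

# Reverse-mode Jacobians with respect to each input (argnums 0, 1, 2).
dy1_dz, dy1_dx, dy1_dy2 = jax.jacrev(y1_func, argnums=(0, 1, 2))(z, x, y2)

# Hand-derived values for comparison:
#   dy1/dz  = [2*z[0], 1] = [10, 1]
#   dy1/dx  = 1
#   dy1/dy2 = -0.2 * staticvar * static_opt = -0.2
```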
The second Sellar JaxExplicitComponent should be written in a similar way.
class SellarDis2(om.JaxExplicitComponent):
    def setup(self):
        # Global Design Variable
        self.add_input('z', val=jnp.zeros(2))
        # Coupling parameter
        self.add_input('y1', val=1.0)
        # Coupling output
        self.add_output('y2', val=1.0, lower=0.1, upper=1000., ref=1.0)

    def setup_partials(self):
        self.declare_partials('*', '*')

    def compute_primal(self, z, y1):
        # if y1[0].real < 0.0:
        #     y1[0] *= -1
        # Because of jit, conditionals cannot be used as is, as in the above two lines of code.
        # Fortunately, JAX provides control flow primitives to deal with that.
        # For if statements, JAX provides the cond function.
        # See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
        # for more information about control flow when using jit
        y1 = jax.lax.cond(y1[0].real < 0.0, lambda x: -x, lambda x: x, y1[0])
        # Outputs are returned as a tuple, in declaration order (only y2 here).
        return (y1**.5 + z[0] + z[1],)
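The difference between a Python if and jax.lax.cond can be reproduced outside OpenMDAO. In this sketch (both function names are hypothetical), the Python-if version fails under jax.jit because the comparison produces a traced value with no concrete truth value. The lax.cond version compiles and can still be differentiated. The call follows the lax.cond(pred, true_fun, false_fun, operand) form used in SellarDis2.

```python
import jax
import jax.numpy as jnp

def sqrt_abs_python_if(y):
    # Python control flow on a traced value: fails under jax.jit.
    if y < 0.0:
        y = -y
    return y ** 0.5

def sqrt_abs_cond(y):
    # lax.cond keeps the branch inside the traced computation.
    y = jax.lax.cond(y < 0.0, lambda v: -v, lambda v: v, y)
    return y ** 0.5

try:
    jax.jit(sqrt_abs_python_if)(-4.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:
    python_if_failed = True

val = jax.jit(sqrt_abs_cond)(-4.0)     # sqrt(4) = 2
slope = jax.grad(sqrt_abs_cond)(-4.0)  # d/dy sqrt(-y) at y = -4 is -0.25
```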
The rest of this code is standard OpenMDAO code. The code can be run as normal.
class SellarDerivatives(om.Group):
    """
    Group containing the Sellar MDA. This version uses the disciplines with derivatives.
    """
    def setup(self):
        self.add_subsystem('d1', SellarDis1(), promotes=['x', 'z', 'y1', 'y2'])
        self.add_subsystem('d2', SellarDis2(), promotes=['z', 'y1', 'y2'])

        obj = self.add_subsystem('obj_cmp', om.ExecComp('obj = x**2 + z[1] + y1 + exp(-y2)', obj=0.0,
                                                        x=0.0, z=np.array([0.0, 0.0]), y1=0.0, y2=0.0),
                                 promotes=['obj', 'x', 'z', 'y1', 'y2'])

        con1 = self.add_subsystem('con_cmp1', om.ExecComp('con1 = 3.16 - y1', con1=0.0, y1=0.0),
                                  promotes=['con1', 'y1'])
        con2 = self.add_subsystem('con_cmp2', om.ExecComp('con2 = y2 - 24.0', con2=0.0, y2=0.0),
                                  promotes=['con2', 'y2'])

        # manually declare partials to allow graceful fallback to FD when nested under a higher
        # level complex step approximation.
        obj.declare_partials(of='*', wrt='*', method='cs')
        con1.declare_partials(of='*', wrt='*', method='cs')
        con2.declare_partials(of='*', wrt='*', method='cs')

        self.set_input_defaults('x', 1.0)
        self.set_input_defaults('z', np.array([5.0, 2.0]))
prob = om.Problem()
prob.model = model = SellarDerivatives()
model.add_design_var('z', lower=np.array([-10.0, 0.0]), upper=np.array([10.0, 10.0]))
model.add_design_var('x', lower=0.0, upper=10.0)
model.add_objective('obj')
model.add_constraint('con1', upper=0.0)
model.add_constraint('con2', upper=0.0)
model.add_constraint('x', upper=11.0, linear=True)
prob.set_solver_print(level=0)
prob.driver = om.ScipyOptimizeDriver(optimizer='SLSQP', tol=1e-9, disp=False)
prob.setup(check=False, mode='fwd')
prob.run_driver()
print(prob.get_val('obj'))
print(prob.get_val('z'))
print(prob.get_val('x'))
[3.18339401]
[1.97763890e+00 8.40262301e-15]
[3.61684e-16]