Computing Partial Derivatives using JaxExplicitComponent
This notebook gives an example of using JAX to do automatic differentiation (AD) for the Sellar example.
The example contains two JaxExplicitComponents, SellarDis1 and SellarDis2. A static option and a static attribute have been added to SellarDis1 to demonstrate how to handle what we call ‘self statics’ in a jax component. Comments interspersed in the code provide explanations and guidance.
Here is an overview of the steps needed to make use of JAX for your JaxExplicitComponent.

1. Inherit your component from JaxExplicitComponent.
2. Write a method named compute_primal to compute the outputs from the inputs. This method is the same as what you would normally write for the compute method of an ExplicitComponent, but it takes the individual input variables as its arguments rather than a dictionary of the inputs, and it returns the outputs as a tuple. This allows us to use jax’s AD capabilities on this method. Ordering of the inputs and outputs is critical. The order of the inputs passed into the method and the outputs returned from the method must match the order in which they are declared as inputs and outputs in the component. If they don’t, an exception will be raised. Also, discrete inputs, if any, are passed individually as arguments after the continuous variables.
3. By default your component will jit the compute_primal method. If for some reason you don’t want this, you can set self.options['use_jit'] to False. This can be useful when debugging because it allows you to put print statements inside your compute_primal method.
4. For a typical component like SellarDis2 below, that’s it. You can skip step 5.
5. However, if your compute_primal depends on variables that are ‘static’ according to jax, i.e., they affect the output of your compute_primal but are not passed in as arguments, you’ll need to define a get_self_statics method on your component that returns a tuple containing all such variables. The returned tuple must be hashable. If these static values ever change, jax will recompile the compute_primal function, assuming ‘use_jit’ is True. In SellarDis1 below, there is one static attribute, self.staticvar, and one static option variable, self.options['static_opt'].
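To illustrate why the statics tuple must be hashable and why a changed static value leads to recompilation, here is a minimal sketch in plain jax. It does not use OpenMDAO and is not a description of how JaxExplicitComponent works internally. The names scaled_sum and trace_count, and the use of static_argnums, are assumptions chosen only for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # counts how many times jax traces (and so compiles) the function


# Hypothetical stand-in: 'statics' plays the role of the tuple that
# get_self_statics would return. Marking it static makes jax hash it.
@partial(jax.jit, static_argnums=(0,))
def scaled_sum(statics, x):
    global trace_count
    trace_count += 1  # Python side effect, so it runs only while tracing
    factor, offset = statics
    return factor * jnp.sum(x) + offset


x = jnp.arange(3.0)               # [0., 1., 2.]
r1 = scaled_sum((2.0, 1.0), x)    # first call: traced and compiled
r2 = scaled_sum((2.0, 1.0), x)    # same statics: the compiled version is reused
r3 = scaled_sum((3.0, 1.0), x)    # different statics: traced again

# A tuple that contains a list cannot be hashed, so it could not serve as statics.
try:
    hash(([2.0], 1.0))
    unhashable_rejected = False
except TypeError:
    unhashable_rejected = True

print(trace_count, float(r1), float(r3), unhashable_rejected)
```

In this sketch the function body is traced twice across three calls: once for each distinct statics tuple.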
Sellar Example
The following code defines a model containing two Sellar disciplines, SellarDis1 and SellarDis2.
import openmdao.api as om
import numpy as np
import jax
import jax.numpy as jnp


class SellarDis1(om.JaxExplicitComponent):
    def initialize(self):
        # Added this option to this model to demonstrate how having options that affect the output
        # of compute_primal requires special care when using jit. See comments below
        self.options.declare('static_opt', types=(float,), default=1.)

        # Added this to show how to handle a static attribute that affects the output of
        # compute_primal.
        self.staticvar = 1.

    def setup(self):
        # Global Design Variable
        self.add_input('z', val=jnp.zeros(2))

        # Local Design Variable
        self.add_input('x', val=0.)

        # Coupling parameter
        self.add_input('y2', val=1.0)

        # Coupling output
        self.add_output('y1', val=1.0, lower=0.1, upper=1000., ref=0.1)

    # because our compute primal output depends on static variables, in this case self.staticvar
    # and self.options['static_opt'], we must define a get_self_statics method. This method must
    # return a tuple of all static variables that affect the output of compute_primal. Their order
    # in the tuple doesn't matter. If your component happens to have discrete inputs, do NOT return
    # them here. Discrete inputs would be passed into the compute_primal function individually, after
    # the continuous variables, but we don't have any discrete inputs in this example.
    def get_self_statics(self):
        # return value must be hashable. Note that if we only had one static variable we would
        # still need to return a tuple containing that variable and so would need to follow the
        # variable name with a comma, for example: return (self.staticvar,)
        return (self.staticvar, self.options['static_opt'])

    def compute_primal(self, z, x, y2):
        # Note that we multiply our return value by the static variables self.staticvar and
        # self.options['static_opt'] here which means that they do affect the output of
        # compute_primal. This is why we have to return them from get_self_statics.
        return z[0]**2 + z[1] + x - 0.2*y2*self.staticvar*self.options['static_opt']
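As a rough illustration of what jax’s AD yields for an expression like the one in SellarDis1.compute_primal, the sketch below rewrites it as a free function and differentiates it with jax.jacfwd. The function name dis1_primal and the input values are hypothetical. The sketch only shows the partials one would expect by hand; it is not the mechanism OpenMDAO uses internally.

```python
import jax
import jax.numpy as jnp


# Hypothetical free-function version of the SellarDis1 expression.
# The two static values are ordinary Python floats and are not differentiated.
def dis1_primal(z, x, y2, staticvar=1.0, static_opt=1.0):
    return z[0]**2 + z[1] + x - 0.2 * y2 * staticvar * static_opt


z = jnp.array([5.0, 2.0])
x = jnp.array(1.0)
y2 = jnp.array(1.0)

# Forward-mode derivatives of y1 with respect to each of the three inputs.
dy1_dz, dy1_dx, dy1_dy2 = jax.jacfwd(dis1_primal, argnums=(0, 1, 2))(z, x, y2)

# By hand: dy1/dz = [2*z[0], 1], dy1/dx = 1, dy1/dy2 = -0.2
print(dy1_dz, dy1_dx, dy1_dy2)
```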
The second Sellar JaxExplicitComponent should be written in a similar way.
class SellarDis2(om.JaxExplicitComponent):
    def setup(self):
        # Global Design Variable
        self.add_input('z', val=jnp.zeros(2))

        # Coupling parameter
        self.add_input('y1', val=1.0)

        # Coupling output
        self.add_output('y2', val=1.0, lower=0.1, upper=1000., ref=1.0)

    def compute_primal(self, z, y1):
        # if y1[0].real < 0.0:
        #     y1[0] *= -1
        # Because of jit, conditionals cannot be used as is, as in the above two lines of code.
        # Fortunately, JAX provides control flow primitives to deal with that.
        # For if statements, JAX provides the cond function.
        # See https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
        # for more information about control flow when using jit
        y1 = jax.lax.cond(y1[0].real < 0.0, lambda x : -x, lambda x : x, y1)

        return y1**.5 + z[0] + z[1]
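The difference between a Python conditional and jax.lax.cond under jit can be seen in a small standalone sketch. The two free functions below are hypothetical stand-ins for SellarDis2.compute_primal and are not part of OpenMDAO. The first uses lax.cond and runs under jit, while the second uses a plain if on a traced value and is expected to fail.

```python
import jax
import jax.numpy as jnp


@jax.jit
def dis2_with_cond(z, y1):
    # lax.cond picks a branch at run time, so it can be traced by jit.
    y1 = jax.lax.cond(y1[0] < 0.0, lambda v: -v, lambda v: v, y1)
    return y1**0.5 + z[0] + z[1]


@jax.jit
def dis2_with_python_if(z, y1):
    # A Python 'if' needs a concrete bool, which a traced value cannot provide.
    if y1[0] < 0.0:
        y1 = -y1
    return y1**0.5 + z[0] + z[1]


z = jnp.array([5.0, 2.0])
y1 = jnp.array([-4.0])

result = dis2_with_cond(z, y1)  # sign is flipped, so sqrt(4) + 5 + 2

try:
    dis2_with_python_if(z, y1)
    python_if_failed = False
except TypeError:
    # jax's tracer-to-bool conversion error derives from TypeError
    python_if_failed = True

print(result, python_if_failed)
```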
The rest of this code is standard OpenMDAO code. The code can be run as normal.
class SellarDerivatives(om.Group):
    """
    Group containing the Sellar MDA. This version uses the disciplines with derivatives.
    """

    def setup(self):
        self.add_subsystem('d1', SellarDis1(), promotes=['x', 'z', 'y1', 'y2'])
        self.add_subsystem('d2', SellarDis2(), promotes=['z', 'y1', 'y2'])

        obj = self.add_subsystem('obj_cmp', om.ExecComp('obj = x**2 + z[1] + y1 + exp(-y2)', obj=0.0,
                                                        x=0.0, z=np.array([0.0, 0.0]), y1=0.0, y2=0.0),
                                 promotes=['obj', 'x', 'z', 'y1', 'y2'])

        con1 = self.add_subsystem('con_cmp1', om.ExecComp('con1 = 3.16 - y1', con1=0.0, y1=0.0),
                                  promotes=['con1', 'y1'])
        con2 = self.add_subsystem('con_cmp2', om.ExecComp('con2 = y2 - 24.0', con2=0.0, y2=0.0),
                                  promotes=['con2', 'y2'])

        # manually declare partials to allow graceful fallback to FD when nested under a higher
        # level complex step approximation.
        obj.declare_partials(of='*', wrt='*', method='cs')
        con1.declare_partials(of='*', wrt='*', method='cs')
        con2.declare_partials(of='*', wrt='*', method='cs')

        self.set_input_defaults('x', 1.0)
        self.set_input_defaults('z', np.array([5.0, 2.0]))


prob = om.Problem()
prob.model = model = SellarDerivatives()

model.add_design_var('z', lower=np.array([-10.0, 0.0]), upper=np.array([10.0, 10.0]))
model.add_design_var('x', lower=0.0, upper=10.0)
model.add_objective('obj')
model.add_constraint('con1', upper=0.0)
model.add_constraint('con2', upper=0.0)
model.add_constraint('x', upper=11.0, linear=True)

prob.set_solver_print(level=0)

prob.driver = om.ScipyOptimizeDriver(optimizer='SLSQP', tol=1e-9, disp=False)

prob.setup(force_alloc_complex=True, check=False, mode='fwd')

prob.run_driver()
print(prob.get_val('obj'))
print(prob.get_val('z'))
print(prob.get_val('x'))
[3.18339401]
[1.97763890e+00 8.40262301e-15]
[3.61684e-16]
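As a sanity check on the printed result, the two disciplines can be converged by simple fixed-point iteration in plain Python at the reported design point. The sketch below reuses the equations from the components above with staticvar and static_opt both equal to 1, and treats the tiny printed values of z[1] and x as zero. The helper name sellar_mda and the iteration count are assumptions for illustration, and OpenMDAO’s own solvers are not involved.

```python
import math


def sellar_mda(z0, z1, x, iterations=50):
    """Converge the y1/y2 coupling by repeated substitution (hypothetical helper)."""
    y2 = 1.0
    for _ in range(iterations):
        y1 = z0**2 + z1 + x - 0.2 * y2
        # abs() mirrors the sign flip that SellarDis2 applies to a negative y1
        y2 = math.sqrt(abs(y1)) + z0 + z1
    return y1, y2


# Design point taken from the printed optimizer output
z0, z1, x = 1.97763890, 0.0, 0.0
y1, y2 = sellar_mda(z0, z1, x)

obj = x**2 + z1 + y1 + math.exp(-y2)
con1 = 3.16 - y1
con2 = y2 - 24.0

print(obj, con1, con2)
```

At this point con1 comes out essentially zero, which suggests that the first constraint is active at the optimum, and the objective agrees with the printed value to within the rounding of the printed inputs.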