Computing Partial Derivatives of Components Using JAX
To take full advantage of OpenMDAO, the user needs to compute partial derivatives for every Component they write. This can be done using finite differencing, but that can have issues with accuracy and performance. Complex step is another option that offers good accuracy, but it isn’t always possible because it requires the component’s computations to be compatible with complex numbers. In some cases, the user can provide analytic partial derivatives, which likely perform well but can be difficult to derive, depending on the complexity of the component.
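To make the accuracy trade-off concrete, here is a small NumPy-only sketch (not OpenMDAO code) that compares a forward finite difference and a complex-step derivative against the hand-derived derivative of an arbitrarily chosen test function, f(x) = e^x sin x. The function, evaluation point, and step sizes are illustrative assumptions.

```python
import numpy as np

def f(x):
    # Smooth test function; np.exp and np.sin both accept complex arguments
    return np.exp(x) * np.sin(x)

def dfdx_analytic(x):
    # Hand-derived derivative: d/dx[e^x sin x] = e^x (sin x + cos x)
    return np.exp(x) * (np.sin(x) + np.cos(x))

x0 = 1.0

# Forward finite difference: truncation error is O(h), and subtractive
# cancellation limits how small h can usefully be
h_fd = 1e-6
fd = (f(x0 + h_fd) - f(x0)) / h_fd

# Complex step: no subtraction occurs, so h can be tiny and the result is
# accurate to roughly machine precision, but f must accept complex input
h_cs = 1e-30
cs = np.imag(f(x0 + 1j * h_cs)) / h_cs

exact = dfdx_analytic(x0)
fd_err = abs(fd - exact)
cs_err = abs(cs - exact)
print(f"finite difference error: {fd_err:.2e}")
print(f"complex step error:      {cs_err:.2e}")
```

With these choices the finite-difference error is several orders of magnitude larger than the complex-step error. The complex step only works here because every operation in f accepts complex numbers.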
This notebook describes another method: using the optional third-party JAX library to automatically differentiate native Python and NumPy functions. To simplify JAX usage within OpenMDAO, we’ve created two component classes, JaxExplicitComponent and JaxImplicitComponent. These components require only the definition of a compute_primal method, which replaces the compute method for JaxExplicitComponent and the apply_nonlinear method for JaxImplicitComponent.
This notebook will describe in more detail how to create and use a JaxExplicitComponent or JaxImplicitComponent and will give examples.
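As a rough illustration of what automatic differentiation provides, the sketch below differentiates a plain jax.numpy function with jax.jacfwd and checks the result against hand-derived partials. The function primal is a hypothetical stand-in for the body of a compute_primal method; it is not an OpenMDAO component.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # use double precision, as OpenMDAO does

def primal(x, y):
    # Hypothetical function written with jax.numpy, similar in spirit to
    # what a compute_primal method might contain
    return jnp.sin(x) * y**2

x = jnp.array([0.5, 1.0, 1.5])
y = jnp.array([2.0, 3.0, 4.0])

# Jacobians of the output with respect to each input, computed by autodiff
dz_dx, dz_dy = jax.jacfwd(primal, argnums=(0, 1))(x, y)

# Hand-derived partials for comparison (both jacobians are diagonal here)
dz_dx_hand = jnp.diag(jnp.cos(x) * y**2)
dz_dy_hand = jnp.diag(jnp.sin(x) * 2.0 * y)
```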
Before going further, it’s a good idea to acquaint yourself with some of JAX’s ‘sharp edges’, which are described on the “JAX - The Sharp Bits” page of the JAX documentation. This will hopefully make the process of creating a JaxExplicitComponent or JaxImplicitComponent less frustrating.
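One common sharp edge is Python control flow that depends on array values. The sketch below (plain JAX, with hypothetical function names) shows a Python `if` failing under jax.jit and a jnp.where rewrite that works. Catching jax.errors.ConcretizationTypeError assumes a reasonably recent JAX version, in which the more specific boolean-conversion error is a subclass of it.

```python
import jax
import jax.numpy as jnp

def relu_python_if(x):
    # A Python `if` needs a concrete boolean, but under jit x is an abstract tracer
    if x > 0:
        return x
    return 0.0

def relu_where(x):
    # jnp.where selects elementwise without Python branching, so it traces fine
    return jnp.where(x > 0, x, 0.0)

try:
    jax.jit(relu_python_if)(2.0)
    python_if_ok = True
except jax.errors.ConcretizationTypeError:
    python_if_ok = False

result = jax.jit(relu_where)(jnp.array([-1.0, 2.0]))
```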
JAX is an optional dependency of OpenMDAO, so if it is not already installed, install it by running one of the following commands at your operating system’s command prompt:
pip install jax jaxlib
pip install openmdao[jax]
pip install openmdao[all]
The JAX library includes a NumPy-like API, jax.numpy, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with jax.numpy. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs. jax.numpy is typically imported as jnp.
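A quick sketch of the parity between the two APIs, and of one place where they differ: JAX arrays are immutable, so NumPy-style in-place assignment is replaced by the functional `.at[...].set(...)` syntax. The arrays and expression are arbitrary examples.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

a_np = np.linspace(0.0, 1.0, 5)
a_jnp = jnp.linspace(0.0, 1.0, 5)

# The same expression written against either namespace
expr_np = np.sum(np.exp(a_np) * np.cos(a_np))
expr_jnp = jnp.sum(jnp.exp(a_jnp) * jnp.cos(a_jnp))

# JAX arrays are immutable, so in-place item assignment raises TypeError ...
try:
    a_jnp[0] = 10.0
    mutated = True
except TypeError:
    mutated = False

# ... while .at[...].set(...) returns an updated copy and leaves a_jnp unchanged
b_jnp = a_jnp.at[0].set(10.0)
```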
import jax
By default, JAX performs computations in single precision (32-bit floats). OpenMDAO uses double precision, so the following line is needed to enable 64-bit computations:
jax.config.update("jax_enable_x64", True)
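A small check of what the setting changes, assuming it runs in a fresh Python process with the flag set at the start of the program. Default JAX floats become float64, and an increment of 1e-10, which float32 cannot resolve near 1.0, is preserved.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x = jnp.array(1.0)
print(x.dtype)  # float64 once x64 is enabled

# 1e-10 is far below float32 resolution near 1.0 (about 1.2e-7) but well above
# float64 resolution (about 2.2e-16)
y = x + 1e-10

# Explicit float32 arithmetic still rounds the increment away
y32 = jnp.array(1.0, dtype=jnp.float32) + jnp.array(1e-10, dtype=jnp.float32)
```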
Automatic Determination of Derivative Direction
JaxExplicitComponent and JaxImplicitComponent automatically determine the direction they will use to compute their partial jacobians based on their jacobian’s shape. If the jacobian has more rows than columns, they use forward mode; otherwise, they use reverse mode. In general, forward mode requires one pass per column of the jacobian and reverse mode one pass per row, so this choice keeps the number of passes small. The number of columns in a JaxExplicitComponent’s jacobian equals the size of its inputs vector, and the number of columns in a JaxImplicitComponent’s jacobian equals the combined size of its inputs and outputs vectors.
Note that this automatic determination of derivative direction only occurs if the matrix_free attribute is False.
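The rule can be sketched in plain JAX with jax.jacfwd and jax.jacrev. The helper pick_jacobian and the test functions below are hypothetical and only mirror the rule stated above; they are not OpenMDAO’s implementation. For an implicit component, n_inputs would be the combined size of its inputs and outputs.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def pick_jacobian(f, n_inputs, n_outputs):
    """Hypothetical helper: forward mode if rows (outputs) exceed columns (inputs)."""
    if n_outputs > n_inputs:
        return jax.jacfwd(f), "fwd"
    return jax.jacrev(f), "rev"

def tall(x):
    # 2 inputs -> 4 outputs: jacobian has more rows than columns
    return jnp.stack([x[0], x[1], x[0] * x[1], jnp.sin(x[1])])

def wide(x):
    # 5 inputs -> 2 outputs: jacobian has more columns than rows
    return jnp.stack([jnp.sum(x), jnp.prod(x)])

x_tall = jnp.array([1.0, 2.0])
x_wide = jnp.arange(1.0, 6.0)

jac_tall_fn, mode_tall = pick_jacobian(tall, n_inputs=2, n_outputs=4)
jac_wide_fn, mode_wide = pick_jacobian(wide, n_inputs=5, n_outputs=2)

J_tall = jac_tall_fn(x_tall)  # shape (4, 2)
J_wide = jac_wide_fn(x_wide)  # shape (2, 5)
```

Both modes produce the same jacobian; the choice only affects how much work is needed to build it.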
Self Statics
When jax compiles a function, it assumes that the only variables that can change are those that are
passed into the function as arguments and any internal variables that depend on those arguments.
All other variables are treated as static. But what if
our jax component has an option or attribute that contributes to the output of our compute_primal
function? Since that option or attribute doesn’t get passed into the function as an argument, jax
doesn’t know about it. In that case, we must be able to detect when those ‘static’ options or
attributes change so that we can tell jax to recompile the function. Otherwise the outputs of the
function won’t reflect the current values of the static options and attributes.
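The problem can be reproduced in plain JAX. In the sketch below, FakeComp is a hypothetical stand-in for a component, and primal reads its attribute without receiving it as an argument. After the attribute changes, the cached compiled function keeps using the old value.

```python
import jax
import jax.numpy as jnp

class FakeComp:
    """Hypothetical stand-in for a component with an attribute used in the primal."""
    def __init__(self):
        self.mult = 2.0

comp = FakeComp()

def primal(x):
    # comp.mult is not an argument, so its current value is baked into
    # the compiled function at trace time
    return x * comp.mult

jitted = jax.jit(primal)
x = jnp.array([1.0, 2.0])

before = jitted(x)   # traced and compiled with mult == 2.0
comp.mult = 3.0
after = jitted(x)    # same input shape/dtype: cached version reused, still uses 2.0
eager = primal(x)    # the un-jitted function sees the new value
```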
In JaxExplicitComponent and JaxImplicitComponent, we add a method called get_self_statics to handle this situation. get_self_statics is a simple method that returns a tuple containing any option or attribute in your component that will affect the output of your compute_primal method.
If your component doesn’t have any of these ‘self static’ variables, then you don’t have to define get_self_statics.
Here’s a simple example. Suppose a component has an option called ‘mult1’ and an attribute called ‘mult2’, and they’re used in compute_primal as follows:
def compute_primal(self, x):
    return x * self.options['mult1'] * self.mult2
In this case, we must define the following get_self_statics method:
def get_self_statics(self):
    return (self.options['mult1'], self.mult2)
This allows compute_primal to be recompiled whenever self.options['mult1'] or self.mult2 changes.
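To see how returning such values can trigger recompilation, here is a plain-JAX sketch of one way the mechanism could work: the statics tuple is passed to jax.jit as a static argument, which makes it part of the compilation cache key. MultComp and its run and _traced methods are hypothetical, the options dict merely stands in for OpenMDAO’s options, and this is not necessarily how OpenMDAO implements it.

```python
import jax
import jax.numpy as jnp

class MultComp:
    """Hypothetical plain-Python class; it is NOT an OpenMDAO component."""

    def __init__(self, mult1, mult2):
        self.options = {"mult1": mult1}  # dict standing in for OpenMDAO options
        self.mult2 = mult2
        self.n_traces = 0
        # statics is argument 1 of _traced; marking it static makes it part of
        # jit's cache key, so a new tuple value forces a retrace
        self._jitted = jax.jit(self._traced, static_argnums=1)

    def get_self_statics(self):
        return (self.options["mult1"], self.mult2)

    def compute_primal(self, x):
        return x * self.options["mult1"] * self.mult2

    def _traced(self, x, statics):
        # statics is unused in the body; it only keys the compilation cache.
        # This Python counter runs only while JAX is tracing.
        self.n_traces += 1
        return self.compute_primal(x)

    def run(self, x):
        return self._jitted(x, self.get_self_statics())

comp = MultComp(mult1=2.0, mult2=3.0)
x = jnp.array([1.0, 2.0])

r1 = comp.run(x)   # first call traces and compiles
t1 = comp.n_traces
r2 = comp.run(x)   # same statics tuple: compiled version is reused
t2 = comp.n_traces
comp.mult2 = 5.0
r3 = comp.run(x)   # statics tuple changed: JAX retraces with the new value
t3 = comp.n_traces
```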
Note that not all of a component’s options and/or attributes need to be returned from get_self_statics. Only those that are referenced inside compute_primal and affect its outputs should be returned.
Note also that you should not return any variable that you’ve added to your component via add_input, add_output, add_discrete_input, or add_discrete_output, even if you expect it not to change during a run. JAX already knows about all of them and handles changes to them properly.
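A plain-JAX sketch of why ordinary inputs don’t need this treatment: values passed as arguments are traced, so new values with the same shape reuse the compiled function and still produce correct results. The counter relies on the fact that Python side effects inside a jitted function run only during tracing. We assume here, consistent with the compute_primal signature shown earlier, that component variables reach compute_primal as arguments; the function primal is hypothetical.

```python
import jax
import jax.numpy as jnp

calls = []

def primal(x, y):
    calls.append(1)  # Python side effect: runs only while JAX traces the function
    return x * y

jitted = jax.jit(primal)

r1 = jitted(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))
r2 = jitted(jnp.array([5.0, 6.0]), jnp.array([7.0, 8.0]))  # new values, same shapes
```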
Configuration Options
JaxExplicitComponent and JaxImplicitComponent both have the following options:
Debugging
While you normally want the use_jit option to be True for performance reasons, setting use_jit to False often helps when debugging your compute_primal method. This allows you to put print statements in compute_primal or to set breakpoints inside it with a Python debugger.
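In plain JAX, a similar effect can be obtained with the jax.disable_jit() context manager, and jax.debug.print can report runtime values even when jit stays on. The sketch below shows both. We assume, but have not confirmed, that setting use_jit to False behaves comparably to disabling jit in this way; the functions primal and primal_jitted are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

observed = []

def primal(x):
    y = jnp.sin(x) * 2.0
    # Converting to a NumPy array needs concrete values; under jit, y would be
    # an abstract tracer and this line would raise an error
    observed.append(np.asarray(y))
    return y

x = jnp.array([0.0, 1.0])

# With jit disabled, the function runs eagerly, so ordinary Python inspection
# (print, np.asarray, a debugger breakpoint) sees real numbers
with jax.disable_jit():
    out = jax.jit(primal)(x)

# When jit stays on, jax.debug.print can still report runtime values
@jax.jit
def primal_jitted(x):
    y = jnp.sin(x) * 2.0
    jax.debug.print("y = {y}", y=y)
    return y

out_jit = primal_jitted(x)
```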