Computing Partial Derivatives of Components Using JAX

To take full advantage of OpenMDAO, the user needs to provide partial derivatives for every Component they write. Finite difference can do this, but it can suffer from poor accuracy and performance. Complex step is another option with much better accuracy, but it isn't always possible because it requires the component's computations to work with complex numbers. In some cases, the user can provide analytic partial derivatives, which typically perform well but can be difficult to derive, depending on the complexity of the component.
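The accuracy differences described above can be seen in a few lines of Python. This is a minimal sketch on a made-up test function, f(x) = exp(x)·sin(x), comparing a forward finite difference (step 1e-6, an arbitrary choice), a complex step (step 1e-30), and JAX's `jax.grad` against the known analytic derivative. It does not involve OpenMDAO at all.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision, as OpenMDAO uses


def f_np(x):
    # Test function written with NumPy; works for real and complex x
    return np.exp(x) * np.sin(x)


def f_jax(x):
    # Same function written with jax.numpy so JAX can differentiate it
    return jnp.exp(x) * jnp.sin(x)


x0 = 1.0
exact = np.exp(x0) * (np.sin(x0) + np.cos(x0))  # analytic derivative

# Forward finite difference: truncation error is roughly proportional to h
h_fd = 1e-6
fd = (f_np(x0 + h_fd) - f_np(x0)) / h_fd

# Complex step: no subtractive cancellation, so h can be tiny
h_cs = 1e-30
cs = f_np(x0 + 1j * h_cs).imag / h_cs

# Automatic differentiation with JAX
ad = float(jax.grad(f_jax)(x0))

fd_err = abs(fd - exact)
cs_err = abs(cs - exact)
ad_err = abs(ad - exact)
print(f"finite difference error: {fd_err:.1e}")
print(f"complex step error:      {cs_err:.1e}")
print(f"jax.grad error:          {ad_err:.1e}")
```

Here the finite-difference error is on the order of 1e-6, while the complex-step and `jax.grad` results agree with the analytic value to near machine precision.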

This notebook describes another method: using the optional third-party JAX library to automatically differentiate native Python functions written with NumPy-style code (via jax.numpy). To simplify JAX usage within OpenMDAO, we've created two component classes, JaxExplicitComponent and JaxImplicitComponent. These components require only the definition of a compute_primal method, which replaces the compute method for JaxExplicitComponent and the apply_nonlinear method for JaxImplicitComponent.

This notebook will describe in more detail how to create and use a JaxExplicitComponent or JaxImplicitComponent and will give examples.

Before going further, it's a good idea to acquaint yourself with some of JAX's 'sharp edges', described on the 'Sharp Bits' page of the JAX documentation. This should make the process of creating a JaxExplicitComponent or JaxImplicitComponent less frustrating.

The use of JAX is optional in OpenMDAO, so if it is not already installed, the user needs to install it by running one of the following commands at the operating system command prompt:

pip install jax jaxlib
pip install openmdao[jax]
pip install openmdao[all]

The JAX library includes a NumPy-like API, jax.numpy, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with jax.numpy. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs. jax.numpy is typically imported as jnp.
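As a small illustration (plain JAX, not OpenMDAO-specific), the same expression can be written with NumPy and with jax.numpy and gives the same result, and `jax.grad` can differentiate the jax.numpy version. The last lines show one of the sharp edges mentioned earlier: JAX arrays are immutable, so updates use the functional `.at[...].set(...)` syntax instead of in-place assignment.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

x_np = np.linspace(0.0, 1.0, 5)
x_j = jnp.linspace(0.0, 1.0, 5)

# Identical expressions through the two APIs
y_np = np.sin(x_np) * np.exp(-x_np)
y_j = jnp.sin(x_j) * jnp.exp(-x_j)


def total(x):
    # jax.grad needs a scalar-valued function; the gradient of sum(sin(x)) is cos(x)
    return jnp.sum(jnp.sin(x))


g = jax.grad(total)(x_j)

# JAX arrays are immutable: x_j[0] = 5.0 would raise a TypeError.
# .at[...].set(...) returns a new array and leaves x_j unchanged.
x_j2 = x_j.at[0].set(5.0)
print(x_j[0], x_j2[0])
```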

import jax
import jax.numpy as jnp

By default, JAX performs computations in single precision (32-bit floats). OpenMDAO uses double precision, so 64-bit mode must be enabled at startup, before any JAX arrays are created:

jax.config.update("jax_enable_x64", True)
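A quick way to confirm that the setting took effect is to check the dtype of a newly created array. The sketch below also shows, using an arbitrarily chosen increment of 1e-10, why single precision is a problem: a perturbation that small is lost entirely in float32 but survives in float64.

```python
import jax
import jax.numpy as jnp

# Must run at startup, before any JAX arrays are created
jax.config.update("jax_enable_x64", True)

x = jnp.array([1.0, 2.0])
print(x.dtype)  # float64 once x64 mode is enabled

# float32 carries about 7 significant digits, float64 about 16
one32 = jnp.array(1.0, dtype=jnp.float32)
one64 = jnp.array(1.0, dtype=jnp.float64)
lost = (one32 + 1e-10) == one32  # the small increment rounds away
kept = (one64 + 1e-10) != one64  # the small increment is preserved
print(bool(lost), bool(kept))
```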

Automatic Determination of Derivative Direction

JaxExplicitComponent and JaxImplicitComponent automatically choose the direction they use to compute their partial jacobians based on the jacobian's shape. Forward mode requires one pass per jacobian column and reverse mode one pass per row, so if the jacobian has more rows than columns, they use forward mode; otherwise they use reverse mode. The number of columns in a JaxExplicitComponent's jacobian equals the size of its inputs vector, and the number of columns in a JaxImplicitComponent's jacobian equals the combined size of its inputs and outputs vectors. Note that this automatic selection of derivative direction only occurs if the matrix_free attribute is False.
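The selection rule can be mimicked in plain JAX with `jax.jacfwd` and `jax.jacrev`. In the sketch below, `pick_jacobian_transform` is a hypothetical helper written only for illustration (it is not part of OpenMDAO or JAX); it applies the same rows-versus-columns test to a tall function (2 inputs, 5 outputs) and a wide one (5 inputs, 1 output). Both transforms produce the same jacobian; only the cost differs.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pick_jacobian_transform(n_rows, n_cols):
    """Hypothetical helper: forward mode for a tall jacobian
    (more rows than columns), reverse mode otherwise."""
    return jax.jacfwd if n_rows > n_cols else jax.jacrev


def tall(x):
    # 2 inputs -> 5 outputs, so the jacobian is 5 x 2
    return jnp.array([x[0], x[1], x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])


def wide(x):
    # 5 inputs -> 1 output, so the jacobian is 1 x 5
    return jnp.sum(x ** 2, keepdims=True)


x_tall = jnp.array([1.0, 2.0])
x_wide = jnp.arange(5.0)

J_tall = pick_jacobian_transform(5, 2)(tall)(x_tall)  # uses jax.jacfwd
J_wide = pick_jacobian_transform(1, 5)(wide)(x_wide)  # uses jax.jacrev
print(J_tall.shape, J_wide.shape)  # (5, 2) (1, 5)

# Either transform gives the same jacobian; only the number of passes differs
same = jnp.allclose(jax.jacrev(tall)(x_tall), J_tall)
```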

Self Statics

When JAX compiles a function, it assumes that the only values that can change are the arguments passed into the function and any internal variables that depend on those arguments. All other variables are treated as static. But what if our JAX component has an option or attribute that contributes to the output of its compute_primal method? Since that option or attribute isn't passed into the function as an argument, JAX doesn't know about it. In that case, we must be able to detect when those 'static' options or attributes change so that we can tell JAX to recompile the function. Otherwise, the outputs of the function won't reflect the current values of the static options and attributes.
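The problem is easy to reproduce with plain JAX, outside of OpenMDAO. In this minimal sketch, a jitted function reads a value from a hypothetical `Holder` object instead of receiving it as an argument. After the value changes, calling the function again with an input of the same shape and dtype reuses the compiled version, so the result is stale.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class Holder:
    """Hypothetical plain Python object standing in for a component attribute."""

    def __init__(self):
        self.mult = 2.0


h = Holder()


@jax.jit
def scale(x):
    # h.mult is not an argument: its value is read once, during tracing,
    # and baked into the compiled function as a constant.
    return x * h.mult


x = jnp.array([1.0, 2.0])
before = scale(x)  # traced and compiled with h.mult == 2.0 -> [2., 4.]
h.mult = 10.0
after = scale(x)   # same shape/dtype, so the cached version is reused -> still [2., 4.]
print(before, after)
```

Calling `scale` with an input of a different shape would force a retrace and pick up the new value, which can make this kind of bug appear only intermittently.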

In JaxExplicitComponent and JaxImplicitComponent, we add a method called get_self_statics to handle this situation. get_self_statics is a simple method that returns a tuple containing any option or attribute in your component that will affect the output of your compute_primal method. If your component doesn’t have any of these ‘self static’ variables then you don’t have to define get_self_statics.

Here’s a simple example. Suppose my component has an option called ‘mult1’ and an attribute called ‘mult2’, and they’re used in compute_primal as follows:

def compute_primal(self, x):
    return x * self.options['mult1'] * self.mult2

In this case, we would be required to define the get_self_statics method shown below:

def get_self_statics(self):
    return (self.options['mult1'], self.mult2)

This allows compute_primal to be recompiled whenever self.options['mult1'] or self.mult2 changes.
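Why a tuple of values is enough can be seen in plain JAX: hashable values passed as static arguments to `jax.jit` become part of its compilation cache key, so a new value triggers a retrace. The sketch below uses a hypothetical `ToyComp` class (a plain Python object with an options dict, not an OpenMDAO component) and a hypothetical module-level helper `_scaled`. We don't claim this is how OpenMDAO uses get_self_statics internally, only that it illustrates the same idea. The `trace_count` list counts traces, since Python side effects inside a jitted function run only while JAX is tracing it.

```python
from functools import partial

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

trace_count = [0]  # incremented only when JAX (re)traces _scaled


@partial(jax.jit, static_argnums=1)
def _scaled(x, statics):
    # Python side effect: runs during tracing, not on cached calls
    trace_count[0] += 1
    mult1, mult2 = statics
    return x * mult1 * mult2


class ToyComp:
    """Hypothetical stand-in; NOT OpenMDAO's JaxExplicitComponent."""

    def __init__(self):
        self.options = {'mult1': 2.0}  # plain dict standing in for OpenMDAO options
        self.mult2 = 3.0               # ordinary attribute

    def get_self_statics(self):
        # Everything the computation reads from self, as a hashable tuple
        return (self.options['mult1'], self.mult2)

    def evaluate(self, x):
        # The statics tuple is passed as a static argument, so a changed
        # value is a new cache key and forces a retrace.
        return _scaled(x, self.get_self_statics())


comp = ToyComp()
x = jnp.array([1.0, 2.0])
y1 = comp.evaluate(x)  # first call: traced, [6., 12.]
y2 = comp.evaluate(x)  # same statics: cached, no retrace
comp.mult2 = 5.0
y3 = comp.evaluate(x)  # statics changed: retraced, [10., 20.]
print(y1, y3, trace_count[0])  # trace_count is 2
```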

Note that not all of a component’s options and/or attributes need to be returned from get_self_statics. Only those that are referenced inside of compute_primal and affect its outputs should be returned.

Note also that you should not return any variable that you've added to your component via add_input, add_output, add_discrete_input, or add_discrete_output, even if you expect it not to change during a run. JAX already knows about all of them and handles changes to them properly.

Configuration Options

JaxExplicitComponent and JaxImplicitComponent both have the following options:

Debugging

While you normally want the use_jit option to be True for performance reasons, setting use_jit to False often helps when debugging your compute_primal method. It allows you to put print statements in compute_primal or to set breakpoints inside it with a Python debugger.
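The difference is easy to see in plain JAX. In the sketch below, a Python-level operation inside a jitted function (here, recording `repr(x)`) runs only while JAX traces the function, so it sees an abstract tracer rather than numbers. Inside the `jax.disable_jit()` context manager, the same function runs op-by-op and sees actual values. We assume use_jit=False has a similar effect on compute_primal, but this example uses JAX directly rather than OpenMDAO. `jax.debug.print` is an alternative that prints runtime values even when jit is on.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

seen = []  # what a plain print(x) inside the function would have shown


@jax.jit
def primal(x):
    seen.append(repr(x))             # plain Python: sees a tracer under jit
    jax.debug.print("x = {x}", x=x)  # prints runtime values even under jit
    return 2.0 * jnp.sin(x)


x = jnp.array([1.0, 2.0])
primal(x)                  # seen[0] is an abstract tracer, no numbers

with jax.disable_jit():    # run op-by-op, similar in spirit to use_jit=False
    primal(x)              # seen[1] shows the concrete array values

print(seen[0])
print(seen[1])
```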

Examples