Source code for openmdao.components.jax_explicit_comp

"""
An ExplicitComponent that uses JAX for derivatives.
"""

import sys
from openmdao.core.explicitcomponent import ExplicitComponent
from openmdao.utils.jax_utils import jax
from openmdao.utils.om_warnings import issue_warning


class JaxExplicitComponent(ExplicitComponent):
    """
    Explicit component base class whose derivatives are computed with JAX.

    On construction the 'derivs_method' option is set to 'jax' when the ``jax``
    name imported from openmdao.utils.jax_utils is truthy, and to
    `fallback_derivs_method` (after issuing a warning) otherwise.  If the caller
    supplies 'derivs_method' as a keyword argument, that value is left untouched.

    Parameters
    ----------
    fallback_derivs_method : str
        Derivative method used when JAX is not available. Default is 'fd'.
    **kwargs : dict
        Keyword arguments forwarded unchanged to the base class constructor.
    """

    def __init__(self, fallback_derivs_method='fd', **kwargs):  # noqa
        # Refuse to construct on interpreters older than 3.9, before any base
        # class setup takes place.
        if sys.version_info < (3, 9):
            raise RuntimeError("JaxExplicitComponent requires Python 3.9 or newer.")

        super().__init__(**kwargs)

        # Only pick a derivative method ourselves when the caller did not pass
        # one explicitly; an explicit 'derivs_method' kwarg always wins.
        if 'derivs_method' not in kwargs:
            if jax:
                chosen_method = 'jax'
            else:
                # Warn first, then fall back to the requested alternative.
                issue_warning(f"{self.msginfo}: JAX is not available, so '{fallback_derivs_method}' "
                              "will be used for derivatives.")
                chosen_method = fallback_derivs_method

            self.options['derivs_method'] = chosen_method