Source code for openmdao.components.jax_implicit_comp
"""
An ImplicitComponent that uses JAX for derivatives.
"""
import sys
from openmdao.core.implicitcomponent import ImplicitComponent
from openmdao.utils.jax_utils import jax
from openmdao.utils.om_warnings import issue_warning
class JaxImplicitComponent(ImplicitComponent):
    """
    Base class for implicit components that rely on JAX for their derivatives.

    If JAX is importable, the ``derivs_method`` option is set to ``'jax'``.
    If it is not, a warning is issued and ``fallback_derivs_method`` is used
    instead. A ``derivs_method`` supplied explicitly in ``kwargs`` is never
    overridden by this constructor.

    Parameters
    ----------
    fallback_derivs_method : str
        Derivatives method to use when JAX is not available. Default is 'fd'.
    **kwargs : dict
        Additional arguments forwarded to the ImplicitComponent constructor.
    """

    def __init__(self, fallback_derivs_method='fd', **kwargs):  # noqa
        # Fail fast on interpreters older than 3.9, before any base-class setup.
        if sys.version_info < (3, 9):
            raise RuntimeError("JaxImplicitComponent requires Python 3.9 or newer.")

        super().__init__(**kwargs)

        # A caller-supplied derivs_method was already forwarded to the base
        # class through kwargs; leave it untouched.
        if 'derivs_method' in kwargs:
            return

        if not jax:
            # JAX could not be imported: tell the user which method replaces it.
            issue_warning(f"{self.msginfo}: JAX is not available, so "
                          f"'{fallback_derivs_method}' will be used for derivatives.")
            chosen = fallback_derivs_method
        else:
            chosen = 'jax'

        self.options['derivs_method'] = chosen