Source code for openmdao.utils.jax_utils

"""
Utilities for the use of jax in combination with OpenMDAO.
"""
from collections.abc import Callable

try:
    import jax
except ImportError:
    jax = None


def jit_stub(f, *args, **kwargs):
    """
    Stand in for a jit decorator when jax cannot be imported.

    Any extra positional or keyword arguments (such as jit options) are accepted and
    discarded, so the decorated callable is handed back unchanged.

    Parameters
    ----------
    f : Callable
        The function or method being decorated.
    *args : list
        Positional arguments; ignored.
    **kwargs : dict
        Keyword arguments; ignored.

    Returns
    -------
    Callable
        `f` itself, with no wrapping applied.
    """
    # The extra arguments are accepted only so jit-style options can be passed without error.
    del args, kwargs
    undecorated = f
    return undecorated
def register_jax_component(comp_class):
    """
    Provide a class decorator that registers the given class as a pytree node.

    This allows jax to use jit compilation on the methods of this class if they reference
    attributes of the class itself, such as `self.options`.

    Note that this decorator is not necessary if the given class does not reference `self`
    in any methods to which `jax.jit` is applied.

    Parameters
    ----------
    comp_class : class
        The decorated class. Its `_tree_flatten` attribute is passed to
        `jax.tree_util.register_pytree_node` as the flatten function. Its `_tree_unflatten`
        attribute is passed as the unflatten function.

    Returns
    -------
    object
        The same class given as an argument.

    Raises
    ------
    NotImplementedError
        If this class does not define both the `_tree_flatten` and `_tree_unflatten` methods.
    RuntimeError
        If jax is not available.
    """
    # `jax` is None when the optional import at module level failed.
    if jax is None:
        raise RuntimeError("jax is not available. "
                           "Try 'pip install openmdao[jax]' with Python>=3.8.")

    # Check both hooks before registering, so a half-implemented class is never registered.
    if not hasattr(comp_class, '_tree_flatten'):
        raise NotImplementedError(f'class {comp_class} does not implement method _tree_flatten.'
                                  f'\nCannot register {comp_class} as a jax jit-compatible '
                                  f'component.')
    if not hasattr(comp_class, '_tree_unflatten'):
        raise NotImplementedError(f'class {comp_class} does not implement method _tree_unflatten.'
                                  f'\nCannot register class {comp_class} as a jax jit-compatible '
                                  f'component.')

    jax.tree_util.register_pytree_node(comp_class,
                                       comp_class._tree_flatten,
                                       comp_class._tree_unflatten)
    return comp_class