jax_utils.py
Utilities for the use of jax in combination with OpenMDAO.
- openmdao.utils.jax_utils.jit_stub(f, *args, **kwargs)
Provide a dummy jit decorator for use if jax is not available.
- Parameters:
- f : Callable
The function or method to be wrapped.
- *args : list
Positional arguments.
- **kwargs : dict
Keyword arguments.
- Returns:
- Callable
The decorated function.
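The following is a minimal sketch of how a stub like this can stand in for jax.jit. It assumes the stub ignores any jit options and returns the wrapped function unchanged. The name jit_stub_sketch is hypothetical and is not OpenMDAO's implementation, and the try/except import fallback is also an assumed usage pattern.

```python
from functools import partial


def jit_stub_sketch(f, *args, **kwargs):
    """Hypothetical stand-in for jax.jit: ignore jit options, return f unchanged."""
    return f


try:
    import jax
    jit = jax.jit            # real JIT compilation when jax is installed
except ImportError:
    jit = jit_stub_sketch    # plain, uncompiled Python otherwise


@jit
def scaled_sum(x, y):
    return 2.0 * x + y


# Options such as static_argnums are accepted by the stub and simply ignored.
@partial(jit_stub_sketch, static_argnums=(1,))
def power(x, n):
    return x ** n
```

Because the stub accepts *args and **kwargs, code written as @partial(jit, static_argnums=...) keeps working without jax. It just runs without compilation.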
- openmdao.utils.jax_utils.register_jax_component(comp_class)
Provide a class decorator that registers the given class as a jax pytree node.
This allows jax.jit to compile methods of this class that reference attributes of the instance itself, such as self.options.
Note that this decorator is not necessary if the given class does not reference self in any methods to which jax.jit is applied.
- Parameters:
- comp_class : class
The class to be decorated.
- Returns:
- object
The same class given as an argument.
- Raises:
- NotImplementedError
If this class does not define the _tree_flatten and _tree_unflatten methods.
- RuntimeError
If jax is not available.
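The following is a minimal sketch of what such a decorator might do. It assumes the decorator passes the class's _tree_flatten and _tree_unflatten hooks to jax.tree_util.register_pytree_node. The names register_jax_component_sketch and ScaledSquare are hypothetical illustrations, not OpenMDAO's code. The order of the two error checks is also an assumption. The registration and compilation steps run only when jax is installed.

```python
try:
    import jax
    import jax.numpy as jnp
except ImportError:
    jax = None  # corresponds to the documented "jax is not available" case


def register_jax_component_sketch(comp_class):
    """Hypothetical stand-in for register_jax_component; not OpenMDAO's code."""
    # Both hooks are required so jax can take an instance apart and rebuild it.
    if not (hasattr(comp_class, "_tree_flatten")
            and hasattr(comp_class, "_tree_unflatten")):
        raise NotImplementedError(
            f"{comp_class.__name__} must define _tree_flatten and _tree_unflatten.")
    if jax is None:
        raise RuntimeError("jax is not available.")
    # flatten: instance -> (children, aux_data)
    # unflatten: (aux_data, children) -> instance
    jax.tree_util.register_pytree_node(
        comp_class, comp_class._tree_flatten, comp_class._tree_unflatten)
    return comp_class  # same class, now known to jax as a pytree node


class ScaledSquare:
    """Toy class (hypothetical) whose method reads self.options."""

    def __init__(self, scale):
        self.options = {"scale": scale}

    def _tree_flatten(self):
        # No array children; the hashable scale is carried as static aux data.
        return (), (self.options["scale"],)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*aux_data)

    def compute(self, x):
        return self.options["scale"] * x ** 2


if jax is not None:
    register_jax_component_sketch(ScaledSquare)
    # Jitting the unbound method passes `self` as an argument; this only
    # works because ScaledSquare is now a registered pytree node.
    compiled = jax.jit(ScaledSquare.compute)
    result = compiled(ScaledSquare(3.0), jnp.array(2.0))  # 3.0 * 2.0**2
```

In this sketch the scale travels as static auxiliary data. jax therefore treats it as part of the compiled function's cache key rather than as a traced input, so a different scale triggers recompilation. Array-valued attributes would instead be returned as children.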