jax_utils.py

Utilities for the use of jax in combination with OpenMDAO.

openmdao.utils.jax_utils.jit_stub(f, *args, **kwargs)

Provide a dummy jit decorator for use if jax is not available.

Parameters:
f : Callable

The function or method to be wrapped.

*args : list

Positional arguments.

**kwargs : dict

Keyword arguments.

Returns:
Callable

The decorated function.
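The fallback pattern that a stub like this supports can be sketched as follows. This is a minimal illustration, not OpenMDAO's actual source: `local_jit_stub` is a hypothetical stand-in that mirrors the documented signature, and the assumption that the stub returns `f` unchanged (ignoring the extra arguments) is inferred from the description above.

```python
def local_jit_stub(f, *args, **kwargs):
    """Hypothetical stand-in for a jit stub: accept jit-style arguments, return f as-is."""
    # Assumption: the extra positional and keyword arguments are ignored.
    return f


try:
    # Use the real decorator when jax is installed.
    from jax import jit
except ImportError:
    # Otherwise fall back to the no-op stub so decorated code still runs.
    jit = local_jit_stub


def scaled_sum(x, y):
    """Small numeric function that works with or without jax."""
    return 2.0 * x + y


fast_scaled_sum = jit(scaled_sum)
result = float(fast_scaled_sum(1.0, 3.0))
```

With this arrangement, calling code can apply `jit` unconditionally. The only difference when jax is missing is that the function runs uncompiled.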

openmdao.utils.jax_utils.register_jax_component(comp_class)

Provide a class decorator that registers the given class as a pytree_node.

This allows jax to use jit compilation on the methods of this class if they reference attributes of the class itself, such as self.options.

Note that this decorator is not necessary if the given class does not reference self in any methods to which jax.jit is applied.

Parameters:
comp_class : class

The decorated class.

Returns:
object

The same class given as an argument.

Raises:
NotImplementedError

If this class does not define the _tree_flatten and _tree_unflatten methods.

RuntimeError

If jax is not available.
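The registration described above can be sketched with jax's public `jax.tree_util.register_pytree_node` function. This is a hedged illustration, not OpenMDAO's implementation: `register_sketch` and the `Scaler` class are hypothetical names, the order of the two error checks is an assumption, and a plain class stands in for an OpenMDAO component.

```python
try:
    import jax
except ImportError:
    jax = None


def register_sketch(cls):
    """Hypothetical class decorator mirroring the documented behavior."""
    if jax is None:
        raise RuntimeError("jax is not available.")
    if not (hasattr(cls, "_tree_flatten") and hasattr(cls, "_tree_unflatten")):
        raise NotImplementedError(
            f"{cls.__name__} must define _tree_flatten and _tree_unflatten."
        )
    # Tell jax how to take instances apart and put them back together.
    jax.tree_util.register_pytree_node(cls, cls._tree_flatten, cls._tree_unflatten)
    return cls


class Scaler:
    """Plain class (not an OpenMDAO component) whose method reads self."""

    def __init__(self, factor):
        self.factor = factor

    def _tree_flatten(self):
        children = (self.factor,)  # dynamic values that jax traces
        aux_data = None            # static data; none needed here
        return children, aux_data

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children)

    def apply(self, x):
        return self.factor * x


if jax is not None:
    Scaler = register_sketch(Scaler)
    # self is now a valid jit argument because Scaler is a registered pytree.
    scaled = jax.jit(Scaler.apply)(Scaler(3.0), 2.0)
else:
    scaled = None
```

Without the registration step, passing a `Scaler` instance as the `self` argument of a jitted method would fail, because jax would not know how to treat the instance as an input.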