"""
An ImplicitComponent that uses JAX for derivatives.
"""

import sys
import inspect
from types import MethodType
from itertools import chain

from openmdao.core.implicitcomponent import ImplicitComponent
from openmdao.utils.om_warnings import issue_warning
from openmdao.utils.jax_utils import jax, jit, _jax_register_pytree_class, \
    _compute_sparsity, get_vmap_tangents, _update_subjac_sparsity, \
    _jax_derivs2partials, _ensure_returns_tuple


class JaxImplicitComponent(ImplicitComponent):
    """
    Base class for implicit components when using JAX for derivatives.

    Parameters
    ----------
    fallback_derivs_method : str
        The method to use if JAX is not available. Default is 'fd'.
    **kwargs : dict
        Additional arguments to be passed to the base class.

    Attributes
    ----------
    _tangents : dict
        The tangents for the inputs and outputs.
    _sparsity : coo_matrix or None
        The sparsity of the Jacobian.
    _jac_func_ : function or None
        The function that computes the Jacobian.
    _orig_compute_primal : function
        The original compute_primal method.
    _ret_tuple_compute_primal : function
        The compute_primal method that returns a tuple.
    """
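
    # A minimal usage sketch (illustrative, not part of this module). Assuming
    # JaxImplicitComponent is exposed via openmdao.api, a subclass defines
    # compute_primal taking the continuous inputs followed by the outputs and
    # returning the residuals; the names 'a', 'b', and 'x' are hypothetical:
    #
    #     import openmdao.api as om
    #
    #     class QuadraticComp(om.JaxImplicitComponent):
    #         def setup(self):
    #             self.add_input('a', 2.0)
    #             self.add_input('b', -4.0)
    #             self.add_output('x', 1.0)
    #
    #         def compute_primal(self, a, b, x):
    #             # residual R(a, b, x) = a * x**2 + b * x
    #             return a * x ** 2 + b * x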

    def __init__(self, fallback_derivs_method='fd', **kwargs):  # noqa
        if sys.version_info < (3, 9):
            raise RuntimeError("JaxImplicitComponent requires Python 3.9 or newer.")
        super().__init__(**kwargs)

        self._tangents = {'fwd': None, 'rev': None}
        self._sparsity = None

        self._orig_compute_primal = self.compute_primal
        self._ret_tuple_compute_primal = \
            MethodType(_ensure_returns_tuple(self.compute_primal.__func__), self)
        self.compute_primal = self._ret_tuple_compute_primal

        # if derivs_method is explicitly passed in, just use it
        if 'derivs_method' in kwargs:
            return

        if jax:
            self.options['derivs_method'] = 'jax'
        else:
            issue_warning(f"{self.msginfo}: JAX is not available, so "
                          f"'{fallback_derivs_method}' will be used for derivatives.")
            self.options['derivs_method'] = fallback_derivs_method
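
    # Illustrative note on the tuple wrapping above: _ensure_returns_tuple makes
    # a compute_primal that returns a single array behave as if it returned a
    # 1-tuple, so downstream code can treat the residuals uniformly. Roughly:
    #
    #     def compute_primal(self, a, x):
    #         return a * x ** 2          # a bare array...
    #
    #     self.compute_primal(a, x)      # ...is seen as (a * x ** 2,)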

    def _setup_jax(self):
        self._sparsity = None

        if self.compute_primal is None:
            raise RuntimeError(f"{self.msginfo}: compute_primal is not defined for this "
                               "component.")

        if self.matrix_free:
            if self._coloring_info.use_coloring():
                issue_warning(f"{self.msginfo}: coloring has been set but matrix_free is True, "
                              "so coloring will be ignored.")
                self._coloring_info.deactivate()
            self.apply_linear = self._jax_apply_linear
        else:
            if self._coloring_info.use_coloring():
                # ensure coloring (and sparsity) is computed before partials
                raise RuntimeError("Coloring not currently supported for JAX implicit "
                                   "components.")
                # self._get_coloring()
                # if self.best_partial_deriv_direction() == 'fwd':
                #     self.linearize = self._jacfwd_colored
                # else:
                #     self.linearize = self._jacrev_colored
            else:
                if not self._subjacs_info:
                    # auto determine subjac sparsities
                    self.compute_sparsity()
            self.linearize = self._jax_linearize
            self._has_linearize = True

        if self.options['use_jit']:
            # discrete inputs follow self plus the continuous inputs and outputs in the
            # compute_primal arg list; mark them static so jit won't trace them
            static_argnums = []
            idx = len(self._var_rel_names['input']) + len(self._var_rel_names['output']) + 1
            static_argnums.extend(range(idx, idx + len(self._discrete_inputs)))
            self.compute_primal = MethodType(jit(self.compute_primal.__func__,
                                                 static_argnums=static_argnums), self)

        _jax_register_pytree_class(self.__class__)

    def _check_compute_primal_args(self):
        """
        Check that the compute_primal method args are in the correct order.
        """
        args = list(inspect.signature(self._orig_compute_primal).parameters)
        if args and args[0] == 'self':
            args = args[1:]
        compargs = self._get_compute_primal_argnames()
        if args != compargs:
            raise RuntimeError(f"{self.msginfo}: compute_primal method args {args} don't match "
                               f"the args {compargs} mapped from this component's inputs. To "
                               "map inputs to the compute_primal method, set the name used in "
                               "compute_primal to the 'primal_name' arg when calling "
                               "add_input/add_discrete_input. This is only necessary if the "
                               "declared component input name is not a valid Python name.")

    def _get_jacobian_func(self):
        # TODO: modify this to use relevance and possibly compile multiple jac functions
        # depending on DV/response so that we don't compute any derivatives that are
        # always zero.
        if self._jac_func_ is None:
            fjax = jax.jacfwd if self.best_partial_deriv_direction() == 'fwd' else jax.jacrev
            wrt_idxs = list(range(1, len(self._var_abs2meta['input']) +
                                  len(self._var_abs2meta['output']) + 1))
            primal_func = self.compute_primal.__func__
            self._jac_func_ = MethodType(fjax(primal_func, argnums=wrt_idxs), self)

            if self.options['use_jit']:
                static_argnums = tuple(range(1 + len(wrt_idxs),
                                             1 + len(wrt_idxs) + len(self._discrete_inputs)))
                self._jac_func_ = MethodType(jit(self._jac_func_.__func__,
                                                 static_argnums=static_argnums), self)
        return self._jac_func_

    def _update_subjac_sparsity(self, sparsity_iter):
        if self.options['derivs_method'] == 'jax':
            _update_subjac_sparsity(sparsity_iter, self.pathname, self._subjacs_info)
        else:
            super()._update_subjac_sparsity(sparsity_iter)
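
    # Illustrative: jax.jacfwd / jax.jacrev with argnums, as in _get_jacobian_func
    # above, differentiate only with respect to the selected positional args.
    # A standalone sketch ('resid' is a hypothetical stand-in for compute_primal):
    #
    #     import jax
    #     import jax.numpy as jnp
    #
    #     def resid(a, x):
    #         return a * x ** 2
    #
    #     jac = jax.jacfwd(resid, argnums=[0, 1])
    #     dR_da, dR_dx = jac(2.0, jnp.array(3.0))   # (9.0, 12.0)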

    def compute_sparsity(self, direction=None, num_iters=1, perturb_size=1e-9):
        """
        Get the sparsity of the Jacobian.

        Parameters
        ----------
        direction : str
            The direction to compute the sparsity for.
        num_iters : int
            The number of times to run the perturbation iteration.
        perturb_size : float
            The size of the perturbation to use.

        Returns
        -------
        coo_matrix
            The sparsity of the Jacobian.
        """
        if self._sparsity is None:
            self._sparsity = _compute_sparsity(self, direction, num_iters, perturb_size)
        return self._sparsity
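
    # Usage sketch (hypothetical: 'comp' is an instance of a subclass after its
    # setup has run). The result is cached, so repeated calls are cheap:
    #
    #     sparsity = comp.compute_sparsity(direction='fwd', num_iters=2)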

    def _get_tangents(self, direction):
        if self._tangents[direction] is None:
            if direction == 'fwd':
                self._tangents[direction] = \
                    get_vmap_tangents(tuple(chain(self._inputs.values(),
                                                  self._outputs.values())),
                                      direction, fill=1.)
            else:
                self._tangents[direction] = \
                    get_vmap_tangents(tuple(self._outputs.values()), direction, fill=1.)
        return self._tangents[direction]
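
    # Illustrative: batched tangents like these let vmap evaluate many
    # directional derivatives in one pass, e.g. all Jacobian columns via
    # jvp over the identity's rows:
    #
    #     import jax
    #     import jax.numpy as jnp
    #
    #     f = lambda x: jnp.sin(x) * x
    #     x = jnp.ones(3)
    #     tangents = jnp.eye(3)                 # one tangent per column of J
    #     cols = jax.vmap(lambda t: jax.jvp(f, (x,), (t,))[1])(tangents)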

    def declare_coloring(self, **kwargs):
        """
        Declare coloring for this component.

        The 'method' argument is set to 'jax' and passed to the base class.

        Parameters
        ----------
        **kwargs : dict
            Additional arguments to be passed to the base class.
        """
        kwargs['method'] = 'jax'
        super().declare_coloring(**kwargs)

    def _jax_linearize(self, inputs, outputs, partials, discrete_inputs=None,
                       discrete_outputs=None):
        """
        Compute sub-jacobian parts for an implicit component.

        The model is assumed to be in an unscaled state.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        outputs : Vector
            Unscaled, dimensional output variables read via outputs[key].
        partials : partial Jacobian
            Sub-jac components written to jacobian[output_name, input_name].
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        discrete_outputs : dict or None
            If not None, dict containing discrete output values.
        """
        derivs = self._get_jacobian_func()(*self._get_compute_primal_invals(inputs, outputs,
                                                                            discrete_inputs))
        _jax_derivs2partials(self, derivs, partials, self._var_rel_names['output'],
                             chain(self._var_rel_names['input'],
                                   self._var_rel_names['output']))

    def _jax_apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
        r"""
        Compute jac-vector product (implicit).

        The model is assumed to be in an unscaled state.

        If mode is:
            'fwd': (d_inputs, d_outputs) \|-> d_residuals

            'rev': d_residuals \|-> (d_inputs, d_outputs)

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        outputs : Vector
            Unscaled, dimensional output variables read via outputs[key].
        d_inputs : Vector
            See inputs; product must be computed only if var_name in d_inputs.
        d_outputs : Vector
            See outputs; product must be computed only if var_name in d_outputs.
        d_residuals : Vector
            See outputs.
        mode : str
            Either 'fwd' or 'rev'.
        """
        if mode == 'fwd':
            dx = tuple(chain(d_inputs.values(), d_outputs.values()))
            full_invals = tuple(self._get_compute_primal_invals(inputs, outputs,
                                                                self._discrete_inputs))
            x = full_invals[:len(dx)]
            other = full_invals[len(dx):]
            _, deriv_vals = jax.jvp(lambda *args: self.compute_primal(*args, *other),
                                    primals=x, tangents=dx)
            if isinstance(deriv_vals, tuple):
                d_residuals.set_vals(deriv_vals)
            else:
                d_residuals.asarray()[:] = deriv_vals.flatten()
        else:
            inhash = (inputs.get_hash(), outputs.get_hash()) + \
                tuple(self._discrete_inputs.values())
            if inhash != self._vjp_hash:
                # recompute vjp function only if inputs or outputs have changed
                dx = tuple(chain(d_inputs.values(), d_outputs.values()))
                full_invals = tuple(self._get_compute_primal_invals(inputs, outputs,
                                                                    self._discrete_inputs))
                x = full_invals[:len(dx)]
                other = full_invals[len(dx):]
                _, self._vjp_fun = jax.vjp(lambda *args: self.compute_primal(*args, *other), *x)
                self._vjp_hash = inhash

            if self._compute_primals_out_shape is None:
                # the first call always takes the branch above, so x and other are defined here
                shape = jax.eval_shape(lambda *args: self.compute_primal(*args, *other), *x)
                if isinstance(shape, tuple):
                    shape = (tuple(s.shape for s in shape), True,
                             len(self._var_rel_names['input']))
                else:
                    shape = (shape.shape, False, len(self._var_rel_names['input']))
                self._compute_primals_out_shape = shape

            shape, istup, ninputs = self._compute_primals_out_shape

            if istup:
                deriv_vals = self._vjp_fun(tuple(d_residuals.values()))
            else:
                deriv_vals = self._vjp_fun(tuple(d_residuals.values())[0])

            d_inputs.set_vals(deriv_vals[:ninputs])
            d_outputs.set_vals(deriv_vals[ninputs:])
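
# Illustrative (not part of the class): the fwd/rev products above follow the
# standard JAX pattern. For residuals R(x), forward mode computes J @ dx with
# jax.jvp and reverse mode computes J^T @ dr with jax.vjp:
#
#     import jax
#     import jax.numpy as jnp
#
#     def resid(x):
#         return jnp.array([x[0] ** 2 + x[1], x[0] * x[1]])
#
#     x = jnp.array([3.0, 2.0])
#     dx = jnp.array([1.0, 0.0])
#     _, jvp_val = jax.jvp(resid, (x,), (dx,))     # J @ dx
#     _, vjp_fun = jax.vjp(resid, x)
#     (vjp_val,) = vjp_fun(jnp.array([1.0, 0.0]))  # J^T @ dr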