Source code for openmdao.components.jax_explicit_comp

"""
An ExplicitComponent that uses JAX for derivatives.
"""

import sys
from types import MethodType

from openmdao.core.explicitcomponent import ExplicitComponent
from openmdao.utils.om_warnings import issue_warning
from openmdao.utils.jax_utils import jax, jit, \
    _jax_register_pytree_class, _compute_sparsity, get_vmap_tangents, \
    _update_subjac_sparsity, _jax_derivs2partials, _uncompress_jac, _jax2np, \
    _ensure_returns_tuple


class JaxExplicitComponent(ExplicitComponent):
    """
    Base class for explicit components when using JAX for derivatives.

    Parameters
    ----------
    fallback_derivs_method : str
        The method to use if JAX is not available. Default is 'fd'.
    **kwargs : dict
        Additional arguments to be passed to the base class.

    Attributes
    ----------
    _tangents : dict
        The tangents for the inputs and outputs.
    _sparsity : coo_matrix or None
        The sparsity of the Jacobian.
    _jac_func_ : function or None
        The function that computes the jacobian.
    _jac_colored_ : function or None
        The function that computes the colored jacobian.
    _static_hash : tuple
        The hash of the static values.
    _orig_compute_primal : function
        The original compute_primal method.
    _ret_tuple_compute_primal : function
        The compute_primal method that returns a tuple.
    """
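    # A minimal usage sketch (illustrative only, not part of this module): subclass
    # JaxExplicitComponent and write compute_primal in terms of operations JAX can
    # trace. The component and variable names below are hypothetical.
    #
    #     import openmdao.api as om
    #
    #     class Paraboloid(om.JaxExplicitComponent):
    #         def setup(self):
    #             self.add_input('x', 0.0)
    #             self.add_input('y', 0.0)
    #             self.add_output('f_xy', 0.0)
    #
    #         def compute_primal(self, x, y):
    #             return (x - 3.0)**2 + x * y + (y + 4.0)**2 - 3.0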
    def __init__(self, fallback_derivs_method='fd', **kwargs):  # noqa
        if sys.version_info < (3, 9):
            raise RuntimeError("JaxExplicitComponent requires Python 3.9 or newer.")

        super().__init__(**kwargs)
        self._re_init()

        if self.compute_primal is None:
            raise RuntimeError(f"{self.msginfo}: compute_primal is not defined for this "
                               "component.")

        self._orig_compute_primal = self.compute_primal
        self._ret_tuple_compute_primal = \
            MethodType(_ensure_returns_tuple(self.compute_primal.__func__), self)
        self.compute_primal = self._ret_tuple_compute_primal

        # if derivs_method is explicitly passed in, just use it
        if 'derivs_method' in kwargs:
            return

        if jax:
            self.options['derivs_method'] = 'jax'
        else:
            issue_warning(f"{self.msginfo}: JAX is not available, so "
                          f"'{fallback_derivs_method}' will be used for derivatives.")
            self.options['derivs_method'] = fallback_derivs_method
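    # Conceptual sketch of the _ensure_returns_tuple wrapping done above (the real
    # helper lives in openmdao.utils.jax_utils; this simplified version is an
    # assumption for illustration): wrap compute_primal so that a single return
    # value always comes back as a 1-tuple.
    #
    #     def ensure_returns_tuple_sketch(func):
    #         def wrapper(self, *args):
    #             ret = func(self, *args)
    #             return ret if isinstance(ret, tuple) else (ret,)
    #         return wrapper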
    def _re_init(self):
        """
        Re-initialize the component for a new run.
        """
        self._tangents = {'fwd': None, 'rev': None}
        self._sparsity = None
        self._jac_func_ = None
        self._static_hash = None
        self._jac_colored_ = None

    def _setup_jax(self):
        """
        Set up the jax interface for this component.

        This happens in final_setup after all var sizes and partials are set.
        """
        _jax_register_pytree_class(self.__class__)

        self._re_init()

        self.compute = self._jax_compute

        if not self._discrete_inputs and not self.get_self_statics():
            # avoid unnecessary statics checks
            self._statics_changed = self._statics_noop

        if self.matrix_free:
            if self._coloring_info.use_coloring():
                issue_warning(f"{self.msginfo}: coloring has been set but matrix_free is True, "
                              "so coloring will be ignored.")
                self._coloring_info.deactivate()
            self.compute_jacvec_product = self._compute_jacvec_product
        else:
            if self._coloring_info.use_coloring():
                # ensure coloring (and sparsity) is computed before partials
                self._get_coloring()
            else:
                if not self._declared_partials_patterns:
                    # auto determine subjac sparsities
                    self.compute_sparsity()
            self.compute_partials = self._compute_partials
            self._has_compute_partials = True

    def _statics_changed(self, discrete_inputs):
        """
        Determine if jitting is needed based on changes in static values since the last call.

        Parameters
        ----------
        discrete_inputs : dict
            dict containing discrete input values.

        Returns
        -------
        bool
            Whether jitting is needed.
        """
        # if static values change, we need to rejit
        inhash = tuple(discrete_inputs) if discrete_inputs else ()
        inhash = inhash + self.get_self_statics()
        if inhash != self._static_hash:
            self._static_hash = inhash
            return True
        return False

    def _statics_noop(self, discrete_inputs):
        """
        Use this function if the component has no discrete inputs or self statics.

        Parameters
        ----------
        discrete_inputs : dict
            dict containing discrete input values.

        Returns
        -------
        bool
            Always returns False.
        """
        return False

    def _jax_compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None):
        """
        Compute outputs given inputs. The model is assumed to be in an unscaled state.

        An inherited component may choose to either override this function or to define a
        compute_primal function.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        outputs : Vector
            Unscaled, dimensional output variables read via outputs[key].
        discrete_inputs : dict-like or None
            If not None, dict-like object containing discrete input values.
        discrete_outputs : dict-like or None
            If not None, dict-like object containing discrete output values.
        """
        if discrete_outputs:
            returns = self.compute_primal(*self._get_compute_primal_invals(inputs,
                                                                           discrete_inputs))
            outputs.set_vals(returns[:outputs.nvars()])
            self._discrete_outputs.set_vals(returns[outputs.nvars():])
        else:
            outputs.set_vals(self.compute_primal(*self._get_compute_primal_invals(
                inputs, discrete_inputs)))

    def _get_differentiable_compute_primal(self, discrete_inputs):
        """
        Get the compute_primal function for the jacobian.

        This version of the compute primal should take no discrete inputs and return no
        discrete outputs. It will be called when computing the jacobian.

        Parameters
        ----------
        discrete_inputs : iter of discrete values
            The discrete input values.

        Returns
        -------
        function
            The compute_primal function to be used to compute the jacobian.
        """
        # exclude the discrete inputs from the inputs and the discrete outputs from the outputs
        if discrete_inputs:
            if self._discrete_outputs:
                ncontouts = self._outputs.nvars()

                def differentiable_compute_primal(*contvals):
                    return self.compute_primal(*contvals, *discrete_inputs)[:ncontouts]
            else:
                def differentiable_compute_primal(*contvals):
                    return self.compute_primal(*contvals, *discrete_inputs)

            return differentiable_compute_primal
        elif self._discrete_outputs:
            ncontouts = self._outputs.nvars()

            def differentiable_compute_primal(*contvals):
                return self.compute_primal(*contvals)[:ncontouts]

            return differentiable_compute_primal

        return self.compute_primal

    def _get_jax_compute_primal(self, discrete_inputs, need_jit):
        """
        Get the jax version of the compute_primal method.
        """
        compute_primal = self._ret_tuple_compute_primal.__func__
        if need_jit:
            # jit the compute_primal method
            idx = self._inputs.nvars() + 1
            if discrete_inputs:
                static_argnums = list(range(idx, idx + len(discrete_inputs)))
            else:
                static_argnums = []
            compute_primal = jit(compute_primal, static_argnums=static_argnums)

        return MethodType(compute_primal, self)

    def _update_jac_functs(self, discrete_inputs):
        """
        Update the jax function that computes the jacobian for this component if necessary.

        An update is required if jitting is enabled and any static values have changed.

        Parameters
        ----------
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.

        Returns
        -------
        tuple
            The jax functions (jax_compute_primal, jax_compute_jac). Note that these are not
            methods, but rather functions. To make them methods you need to assign
            MethodType(function, self) to an attribute of the instance.
        """
        need_jit = self.options['use_jit']
        if need_jit and self._statics_changed(discrete_inputs):
            self._jac_func_ = None

        if self._jac_func_ is None:
            self.compute_primal = self._get_jax_compute_primal(discrete_inputs, need_jit)
            differentiable_cp = self._get_differentiable_compute_primal(discrete_inputs)

            if self._coloring_info.use_coloring():
                if self._coloring_info.coloring is None:
                    # need to dynamically compute the coloring first
                    self._compute_coloring()

                if self.best_partial_deriv_direction() == 'fwd':
                    self._get_tangents('fwd', self._coloring_info.coloring)

                    # here we'll use the same inputs and a single tangent vector from the vmap
                    # batch to compute a single jvp, which corresponds to a column of the
                    # jacobian (the compressed jacobian in the colored case).
                    def jvp_at_point(tangent, icontvals):
                        # [1] is the derivative, [0] is the primal (we don't need the primal)
                        return jax.jvp(differentiable_cp, icontvals, tangent)[1]

                    # vectorize over the last axis of the tangent vectors and use the same
                    # inputs for all cases.
                    self._jac_func_ = jax.vmap(jvp_at_point, in_axes=[-1, None], out_axes=-1)
                    self._jac_colored_ = self._jacfwd_colored
                else:  # rev
                    def vjp_at_point(cotangent, icontvals):
                        # Returns primal and a function to compute VJP so just take [1],
                        # the vjp function
                        return jax.vjp(differentiable_cp, *icontvals)[1](cotangent)

                    self._get_tangents('rev', self._coloring_info.coloring)

                    # Batch over last axis of cotangents
                    self._jac_func_ = jax.vmap(vjp_at_point, in_axes=[-1, None], out_axes=-1)
                    self._jac_colored_ = self._jacrev_colored
            else:
                self._jac_colored_ = None
                fjax = jax.jacfwd if self.best_partial_deriv_direction() == 'fwd' else jax.jacrev
                wrt_idxs = list(range(len(self._var_abs2meta['input'])))
                self._jac_func_ = fjax(differentiable_cp, argnums=wrt_idxs)

                if need_jit:
                    self._jac_func_ = jax.jit(self._jac_func_)
    def declare_coloring(self, **kwargs):
        """
        Declare coloring for this component.

        The 'method' argument is set to 'jax' and passed to the base class.

        Parameters
        ----------
        **kwargs : dict
            Additional arguments to be passed to the base class.
        """
        if 'method' in kwargs and kwargs['method'] != self.options['derivs_method']:
            raise ValueError(f"method must be '{self.options['derivs_method']}' for this "
                             f"component but got '{kwargs['method']}'.")
        kwargs['method'] = self.options['derivs_method']
        super().declare_coloring(**kwargs)
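    # Typical call pattern (hypothetical component instance `comp`): omit 'method'
    # and let the component supply its own derivs_method.
    #
    #     comp.declare_coloring(num_full_jacs=2, tol=1e-25)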
    # we define _compute_partials here and possibly later rename it to compute_partials instead
    # of making this the base class version as we did with compute, because the existence of a
    # compute_partials method that is not the base class method is used to determine if a given
    # component computes its own partials.
    def _compute_partials(self, inputs, partials, discrete_inputs=None):
        """
        Compute sub-jacobian parts. The model is assumed to be in an unscaled state.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        partials : Jacobian
            Sub-jac components written to partials[output_name, input_name].
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        """
        discrete_inputs = discrete_inputs.values() if discrete_inputs else ()
        self._update_jac_functs(discrete_inputs)

        if self._jac_colored_ is not None:
            return self._jac_colored_(inputs, partials, discrete_inputs)

        derivs = self._jac_func_(*inputs.values())

        # check to see if we even need this with jax. A jax component doesn't need to map string
        # keys to partials. We could just use the jacobian as an array to compute the
        # derivatives. Maybe make a simple JaxJacobian that is just a thin wrapper around the
        # jacobian array. The only issue is do higher level jacobians need the subjacobian info?
        _jax_derivs2partials(self, derivs, partials, self._var_rel_names['output'],
                             self._var_rel_names['input'])

    def _jacfwd_colored(self, inputs, partials, discrete_inputs=None):
        """
        Compute the forward jacobian using vmap with jvp and coloring.

        Parameters
        ----------
        inputs : dict
            The inputs to the component.
        partials : dict
            The partials to compute.
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        """
        J = self._jac_func_(self._tangents['fwd'], tuple(inputs.values()))
        partials.set_dense_jac(self, _uncompress_jac(self, _jax2np(J), 'fwd'))

    def _jacrev_colored(self, inputs, partials, discrete_inputs=None):
        """
        Compute the reverse jacobian using vmap with vjp and coloring.

        Parameters
        ----------
        inputs : dict
            The inputs to the component.
        partials : dict
            The partials to compute.
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        """
        J = self._jac_func_(self._tangents['rev'], tuple(inputs.values()))
        partials.set_dense_jac(self, _uncompress_jac(self, _jax2np(J).T, 'rev'))
    def compute_sparsity(self, direction=None, num_iters=1, perturb_size=1e-9):
        """
        Get the sparsity of the Jacobian.

        Parameters
        ----------
        direction : str
            The direction to compute the sparsity for.
        num_iters : int
            The number of times to run the perturbation iteration.
        perturb_size : float
            The size of the perturbation to use.

        Returns
        -------
        coo_matrix
            The sparsity of the Jacobian.
        """
        if self._sparsity is None:
            if self.options['derivs_method'] == 'jax':
                self._sparsity = _compute_sparsity(self, direction, num_iters, perturb_size)
            else:
                self._sparsity = super().compute_sparsity(direction=direction,
                                                          num_iters=num_iters,
                                                          perturb_size=perturb_size)
        return self._sparsity
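    # Hedged sketch of perturbation-based sparsity detection (this is NOT the
    # actual _compute_sparsity implementation, just the general idea): evaluate
    # the jacobian at randomly perturbed points and OR together the nonzero
    # patterns so entries that are only coincidentally zero are less likely to
    # be dropped.
    #
    #     import numpy as np
    #     from scipy.sparse import coo_matrix
    #
    #     def detect_sparsity(jac_func, x0, num_iters=1, perturb_size=1e-9):
    #         nz = None
    #         for _ in range(num_iters):
    #             x = x0 + perturb_size * np.random.random(x0.shape)
    #             mask = np.asarray(jac_func(x)) != 0.
    #             nz = mask if nz is None else (nz | mask)
    #         return coo_matrix(nz.astype(float))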
    def _update_subjac_sparsity(self, sparsity_iter):
        """
        Update the sparsity info for this component's subjacs.
        """
        if self.options['derivs_method'] == 'jax':
            _update_subjac_sparsity(sparsity_iter, self.pathname, self._subjacs_info)
        else:
            super()._update_subjac_sparsity(sparsity_iter)

    def _get_tangents(self, direction, coloring=None):
        """
        Get the tangents for the inputs or outputs.

        If coloring is not None, then the tangents will be compressed based on the coloring.

        Parameters
        ----------
        direction : str
            The direction to get the tangents for.
        coloring : Coloring
            The coloring to use.

        Returns
        -------
        tuple
            The tangents.
        """
        if self._tangents[direction] is None:
            if direction == 'fwd':
                self._tangents[direction] = \
                    get_vmap_tangents(tuple(self._inputs.values()), direction, fill=1.,
                                      coloring=coloring)
            else:
                self._tangents[direction] = \
                    get_vmap_tangents(tuple(self._outputs.values()), direction, fill=1.,
                                      coloring=coloring)
        return self._tangents[direction]

    def _compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode, discrete_inputs=None):
        r"""
        Compute jac-vector product (explicit). The model is assumed to be in an unscaled state.

        If mode is:
            'fwd': d_inputs \|-> d_outputs
            'rev': d_outputs \|-> d_inputs

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        d_inputs : Vector
            See inputs; product must be computed only if var_name in d_inputs.
        d_outputs : Vector
            See outputs; product must be computed only if var_name in d_outputs.
        mode : str
            Either 'fwd' or 'rev'.
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        """
        if mode == 'fwd':
            dx = tuple(d_inputs.values())
            full_invals = tuple(self._get_compute_primal_invals(inputs, discrete_inputs))
            x = full_invals[:len(dx)]
            other = full_invals[len(dx):]
            _, deriv_vals = jax.jvp(lambda *args: self.compute_primal(*args, *other),
                                    primals=x, tangents=dx)
            d_outputs.set_vals(deriv_vals)
        else:
            inhash = ((inputs.get_hash(),) + tuple(self._discrete_inputs.values()) +
                      self.get_self_statics())
            if inhash != self._static_hash:
                ncont_ins = d_inputs.nvars()
                full_invals = tuple(self._get_compute_primal_invals(inputs, discrete_inputs))
                x = full_invals[:ncont_ins]
                other = full_invals[ncont_ins:]
                # recompute vjp function if inputs have changed
                _, self._vjp_fun = jax.vjp(lambda *args: self.compute_primal(*args, *other), *x)
                self._static_hash = inhash

            deriv_vals = self._vjp_fun(tuple(d_outputs.values()) +
                                       tuple(self._discrete_outputs.values()))
            d_inputs.set_vals(deriv_vals)
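# Standalone sketch of the fwd/rev products computed in _compute_jacvec_product
# above (illustrative function and vectors, not OpenMDAO API). In 'fwd' mode a
# jvp gives J @ dx; in 'rev' mode the vjp function is built once and reused to
# apply J.T to cotangents.
#
#     import jax
#     import jax.numpy as jnp
#
#     def f(x):
#         return jnp.array([x[0] * x[1], x[0] + x[1]])
#
#     x = jnp.array([3., 4.])
#     _, d_out = jax.jvp(f, (x,), (jnp.array([1., 0.]),))  # 'fwd': J @ dx
#     _, vjp_fun = jax.vjp(f, x)                           # 'rev': build once,
#     (d_in,) = vjp_fun(jnp.array([1., 0.]))               # then apply J.T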