"""
An ImplicitComponent that uses JAX for derivatives.
"""
import sys
import inspect
from types import MethodType
from itertools import chain
from functools import partial
from openmdao.core.implicitcomponent import ImplicitComponent
from openmdao.utils.om_warnings import issue_warning
from openmdao.utils.jax_utils import jax, jit, _jax_register_pytree_class, \
_compute_sparsity, get_vmap_tangents, _update_subjac_sparsity, \
_jax_derivs2partials, _ensure_returns_tuple, _update_add_input_kwargs, \
_update_add_output_kwargs, _re_init, _get_differentiable_compute_primal, _uncompress_jac, \
_jax2np, _compute_output_shapes
from openmdao.utils.code_utils import get_return_names, get_function_deps
import openmdao.utils.coloring as coloring_mod


class JaxImplicitComponent(ImplicitComponent):
"""
    Base class for implicit components when using JAX for derivatives.

    Parameters
----------
matrix_free : bool
If True, this component will compute derivatives using matrix vector products.
fallback_derivs_method : str
The method to use if JAX is not available. Default is 'fd'.
**kwargs : dict
        Additional arguments to be passed to the base class.

    Attributes
----------
_tangents : dict
The tangents for the inputs and outputs.
_sparsity : coo_matrix or None
The sparsity of the Jacobian.
_jac_func_ : function or None
The function that computes the jacobian.
_orig_compute_primal : function
The original compute_primal method.
_ret_tuple_compute_primal : function
The compute_primal method that returns a tuple.
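
    Examples
    --------
    A minimal sketch (assuming a hypothetical quadratic residual r = x**2 - a) of a
    subclass that defines compute_primal to return its residual::

        import openmdao.api as om

        class QuadComp(om.JaxImplicitComponent):
            def setup(self):
                self.add_input('a', 1.0)
                self.add_output('x', 1.0)

            def compute_primal(self, a, x):
                # residual for x**2 = a; jax traces this function to build
                # the partial derivatives automatically
                return x ** 2 - a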
"""

    def __init__(self, matrix_free=False, fallback_derivs_method='fd', **kwargs):  # noqa
if sys.version_info < (3, 9):
raise RuntimeError("JaxImplicitComponent requires Python 3.9 or newer.")
super().__init__(**kwargs)
self.matrix_free = matrix_free
_re_init(self)
if self.compute_primal is None:
raise RuntimeError(f"{self.msginfo}: compute_primal is not defined for this component.")
self._orig_compute_primal = self.compute_primal
self._ret_tuple_compute_primal = \
MethodType(_ensure_returns_tuple(self.compute_primal.__func__), self)
self.compute_primal = self._ret_tuple_compute_primal
        # if a derivs_method other than 'jax' was explicitly passed in, just use it
if 'derivs_method' in kwargs and kwargs['derivs_method'] != 'jax':
return
if jax:
self.options['derivs_method'] = 'jax'
else:
issue_warning(f"{self.msginfo}: JAX is not available, so "
f"'{fallback_derivs_method}' will be used for derivatives.")
self.options['derivs_method'] = fallback_derivs_method

    def _declare_options(self):
"""
Declare options before kwargs are processed in the init method.
"""
super()._declare_options()
self.options.declare('default_to_dyn_shapes', types=bool, default=False,
desc='If True, use dynamic shaping for any variables whose value is '
'scalar and whose shape is not explicitly set. Inputs will use '
'shape_by_conn and outputs will use a compute_shape method based '
'on jax.eval_shape. Default is False.')
self.options.undeclare("distributed")

    def _setup_check(self):
"""
Check if inputs and outputs have been added, and if not, determine them from compute_primal.
"""
_re_init(self)
if len(self._var_rel_names['input']) > 0 or len(self._var_rel_names['output']) > 0:
return
if not self._var_rel_names['input']:
for argname in inspect.signature(self._orig_compute_primal).parameters:
self.add_input(argname)
if not self._var_rel_names['output']:
for i, name in enumerate(get_return_names(self._orig_compute_primal)):
if name is None:
name = f'out_{i}'
self.add_output(name)

    def add_output(self, name, **kwargs):
        """
        Add an output to the component.

        This overrides the base class method to update the kwargs to use dynamic shaping by
        default.

        Parameters
        ----------
        name : str
            The name of the output.
        **kwargs : dict
            The kwargs to pass to the base class method.
"""
super().add_output(name, **_update_add_output_kwargs(self, name, **kwargs))

    def _setup_jax(self):
        """
        Register this class as a pytree and skip statics checks when they aren't needed.
        """
_jax_register_pytree_class(self.__class__)
if not self._discrete_inputs and not self._discrete_outputs and not self.get_self_statics():
# avoid unnecessary statics checks
self._statics_changed = self._statics_noop

    def _check_first_linearize(self):
        """
        On the first call to linearize, retrieve the coloring or compute the sparsity.
        """
if self._first_call_to_linearize:
self._first_call_to_linearize = False # only do this once
if not self.matrix_free and self._coloring_info.use_coloring() and \
coloring_mod._use_partial_sparsity:
self._get_coloring()
if self._jacobian is not None:
self._jacobian._restore_approx_sparsity()
elif self._do_sparsity and self.options['derivs_method'] == 'jax':
self.compute_sparsity()

    def _setup_partials(self):
"""
Call setup_partials in components.
"""
if self.options['derivs_method'] == 'jax':
if self.matrix_free:
if self._coloring_info.use_coloring():
issue_warning(f"{self.msginfo}: coloring has been set but matrix_free is True, "
"so coloring will be ignored.")
self._coloring_info.deactivate()
self.apply_linear = self._jax_apply_linear
else:
# if user hasn't declared partials, try to infer them from the compute_primal. If
# that fails, declare all partials.
if not self._declared_partials_patterns:
self._do_sparsity = True
try:
deps = list(get_function_deps(self._orig_compute_primal,
self._var_rel_names['output']))
except Exception as err:
issue_warning(f"{self.msginfo}: Couldn't determine function graph for "
f"compute_primal: {err}")
deps = []
if deps:
contvars = set(self._var_rel_names['input'])
contvars.update(self._var_rel_names['output'])
for of, wrt in deps:
if of in contvars and wrt in contvars:
self.declare_partials(of, wrt)
else:
self.declare_partials('*', '*')
self.linearize = self._jax_linearize
self._has_linearize = True
super()._setup_partials()

    def _statics_changed(self, discrete_inputs):
        """
        Determine if jitting is needed based on changes in static values since the last call.

        Parameters
        ----------
        discrete_inputs : dict
            dict containing discrete input values.

        Returns
        -------
        bool
            Whether jitting is needed.
"""
# if static values change, we need to rejit
inhash = hash((tuple(discrete_inputs) if discrete_inputs else (), self.get_self_statics()))
if inhash != self._static_hash:
self._static_hash = inhash
return True
return False

    def _statics_noop(self, discrete_inputs):
        """
        Use this function if the component has no discrete inputs or self statics.

        Parameters
        ----------
        discrete_inputs : dict
            dict containing discrete input values.

        Returns
        -------
        bool
            Always returns False.
"""
return False

    def _get_jax_compute_primal(self, discrete_inputs, need_jit):
        """
        Get the jax version of the compute_primal method.

        Parameters
        ----------
        discrete_inputs : iterable
            Iterable of discrete input values, if any.
        need_jit : bool
            If True, jit the compute_primal method, treating discrete inputs as static.

        Returns
        -------
        MethodType
            The jax compute_primal method, bound to this instance.
        """
compute_primal = self._ret_tuple_compute_primal.__func__
if need_jit:
# jit the compute_primal method
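            # the unbound compute_primal's args are (self, *inputs, *outputs,
            # *discrete_inputs), so the first discrete input sits at index
            # nvars(inputs) + nvars(outputs) + 1, where the +1 accounts for self.
            # discrete inputs must be marked static so jit doesn't trace them.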
idx = self._inputs.nvars() + self._outputs.nvars() + 1
if discrete_inputs:
static_argnums = list(range(idx, idx + len(discrete_inputs)))
else:
static_argnums = []
compute_primal = jit(compute_primal, static_argnums=static_argnums)
return MethodType(compute_primal, self)

    def _update_jac_functs(self, discrete_inputs):
        """
        Update the jax function that computes the jacobian for this component if necessary.

        An update is required if jitting is enabled and any static values have changed.
        This method updates self.compute_primal, self._jac_func_, and self._jac_colored_
        in place rather than returning the functions.

        Parameters
        ----------
        discrete_inputs : iterable
            Iterable of discrete input values, if any.
"""
need_jit = self.options['use_jit']
if need_jit and self._statics_changed(discrete_inputs):
self._jac_func_ = None
if self._jac_func_ is None:
self.compute_primal = self._get_jax_compute_primal(discrete_inputs, need_jit)
differentiable_cp = _get_differentiable_compute_primal(self, discrete_inputs)
if self._coloring_info.use_coloring():
if self._coloring_info.coloring is None:
# need to dynamically compute the coloring first
self._compute_coloring()
if self.best_partial_deriv_direction() == 'fwd':
self._get_tangents('fwd', self._coloring_info.coloring)
# here we'll use the same inputs and a single tangent vector from the vmap
# batch to compute a single jvp, which corresponds to a column of the
# jacobian (the compressed jacobian in the colored case).
def jvp_at_point(tangent, icontvals):
# [1] is the derivative, [0] is the primal (we don't need the primal)
return jax.jvp(differentiable_cp, icontvals, tangent)[1]
# vectorize over the last axis of the tangent vectors and use the same
# inputs for all cases.
self._jac_func_ = jax.vmap(jvp_at_point, in_axes=[-1, None], out_axes=-1)
self._jac_colored_ = self._jacfwd_colored
else: # rev
def vjp_at_point(cotangent, icontvals):
# Returns primal and a function to compute VJP so just take [1],
# the vjp function
return jax.vjp(differentiable_cp, *icontvals)[1](cotangent)
self._get_tangents('rev', self._coloring_info.coloring)
# Batch over last axis of cotangents
self._jac_func_ = jax.vmap(vjp_at_point, in_axes=[-1, None], out_axes=-1)
self._jac_colored_ = self._jacrev_colored
else:
self._jac_colored_ = None
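                # no coloring: compute the full jacobian of the residuals with respect
                # to all continuous inputs and outputs using jacfwd or jacrev.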
fjax = jax.jacfwd if self.best_partial_deriv_direction() == 'fwd' else jax.jacrev
wrt_idxs = list(range(len(self._var_abs2meta['input']) +
len(self._var_abs2meta['output'])))
self._jac_func_ = fjax(differentiable_cp, argnums=wrt_idxs)
if need_jit:
self._jac_func_ = jax.jit(self._jac_func_)

    def declare_coloring(self, **kwargs):
        """
        Declare coloring for this component.

        The 'method' argument is set to this component's derivs_method and passed to
        the base class.

        Parameters
        ----------
        **kwargs : dict
            Additional arguments to be passed to the base class.
"""
        if 'method' in kwargs and kwargs['method'] != self.options['derivs_method']:
            raise ValueError(f"method must be '{self.options['derivs_method']}' for this "
                             f"component but got '{kwargs['method']}'.")
kwargs['method'] = self.options['derivs_method']
super().declare_coloring(**kwargs)

    def _jax_linearize(self, inputs, outputs, partials, discrete_inputs=None,
                       discrete_outputs=None):
        """
        Compute sub-jacobian parts for an implicit component.

        The model is assumed to be in an unscaled state.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        outputs : Vector
            Unscaled, dimensional output variables read via outputs[key].
        partials : partial Jacobian
            Sub-jac components written to jacobian[output_name, input_name].
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        discrete_outputs : dict or None
            If not None, dict containing discrete output values.
"""
discrete_inputs = discrete_inputs.values() if discrete_inputs else ()
self._update_jac_functs(discrete_inputs)
if self._jac_colored_ is not None:
return self._jac_colored_(inputs, outputs, partials)
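        # uncolored case: evaluate the full jacobian function, then scatter the
        # resulting sub-jacobians into the partials object.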
derivs = self._jac_func_(*chain(inputs.values(), outputs.values()))
_jax_derivs2partials(self, derivs, partials, self._var_rel_names['output'],
chain(self._var_rel_names['input'], self._var_rel_names['output']))

    def _jacfwd_colored(self, inputs, outputs, partials):
        """
        Compute the forward jacobian using vmap with jvp and coloring.

        Parameters
        ----------
        inputs : dict
            The inputs to the component.
        outputs : dict
            The outputs to the component.
        partials : dict
            The partials to compute.
"""
J = self._jac_func_(self._tangents['fwd'], tuple(chain(inputs.values(), outputs.values())))
partials.set_dense_jac(self, _uncompress_jac(self, _jax2np(J), 'fwd'))

    def _jacrev_colored(self, inputs, outputs, partials):
        """
        Compute the reverse jacobian using vmap with vjp and coloring.

        Parameters
        ----------
        inputs : dict
            The inputs to the component.
        outputs : dict
            The outputs to the component.
        partials : dict
            The partials to compute.
"""
J = self._jac_func_(self._tangents['rev'], tuple(chain(inputs.values(), outputs.values())))
partials.set_dense_jac(self, _uncompress_jac(self, _jax2np(J).T, 'rev'))

    def compute_sparsity(self, direction=None, num_iters=1, perturb_size=1e-9):
        """
        Get the sparsity of the Jacobian.

        Parameters
        ----------
        direction : str or None
            The direction to compute the sparsity for.
        num_iters : int
            The number of times to run the perturbation iteration.
        perturb_size : float
            The size of the perturbation to use.

        Returns
        -------
        coo_matrix
            The sparsity of the Jacobian.
"""
if self._sparsity is None:
if self.options['derivs_method'] == 'jax':
self._sparsity = _compute_sparsity(self, direction, num_iters, perturb_size)[0]
else:
self._sparsity = super().compute_sparsity(direction=direction, num_iters=num_iters,
perturb_size=perturb_size)[0]
return self._sparsity

    def _update_subjac_sparsity(self, sparsity_iter):
if self.options['derivs_method'] == 'jax':
_update_subjac_sparsity(sparsity_iter, self.pathname, self._subjacs_info)
else:
super()._update_subjac_sparsity(sparsity_iter)

    def _get_tangents(self, direction, coloring=None):
        """
        Get the tangents for the inputs and/or outputs.

        If coloring is not None, then the tangents will be compressed based on the coloring.

        Parameters
        ----------
        direction : str
            The direction to get the tangents for.
        coloring : Coloring
            The coloring to use.

        Returns
        -------
        tuple
            The tangents.
"""
if self._tangents[direction] is None:
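            # in 'fwd' mode the tangents seed all continuous inputs and outputs; in
            # 'rev' mode the cotangents seed the residuals, which have the same
            # shapes as the outputs.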
if direction == 'fwd':
self._tangents[direction] = get_vmap_tangents(tuple(chain(self._inputs.values(),
self._outputs.values())),
direction, fill=1., coloring=coloring)
else:
self._tangents[direction] = get_vmap_tangents(tuple(self._outputs.values()),
direction, fill=1., coloring=coloring)
return self._tangents[direction]

    def _jax_apply_linear(self, inputs, outputs, d_inputs, d_outputs, d_residuals, mode):
        r"""
        Compute jac-vector product (implicit). The model is assumed to be in an unscaled state.

        If mode is:
            'fwd': (d_inputs, d_outputs) \|-> d_residuals

            'rev': d_residuals \|-> (d_inputs, d_outputs)

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        outputs : Vector
            Unscaled, dimensional output variables read via outputs[key].
        d_inputs : Vector
            See inputs; product must be computed only if var_name in d_inputs.
        d_outputs : Vector
            See outputs; product must be computed only if var_name in d_outputs.
        d_residuals : Vector
            See outputs.
        mode : str
            Either 'fwd' or 'rev'.
"""
if mode == 'fwd':
dx = tuple(chain(d_inputs.values(), d_outputs.values()))
full_invals = tuple(self._get_compute_primal_invals(inputs, outputs,
self._discrete_inputs))
x = full_invals[:len(dx)]
other = full_invals[len(dx):]
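            # differentiate only with respect to the continuous variables in x; any
            # discrete inputs in 'other' are closed over as constants.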
_, deriv_vals = jax.jvp(lambda *args: self.compute_primal(*args, *other),
primals=x, tangents=dx)
if isinstance(deriv_vals, tuple):
d_residuals.set_vals(deriv_vals)
else:
d_residuals.asarray()[:] = deriv_vals.flatten()
else:
inhash = (inputs.get_hash(), outputs.get_hash()) + tuple(self._discrete_inputs.values())
if inhash != self._vjp_hash:
# recompute vjp function only if inputs or outputs have changed
dx = tuple(chain(d_inputs.values(), d_outputs.values()))
full_invals = tuple(self._get_compute_primal_invals(inputs, outputs,
self._discrete_inputs))
x = full_invals[:len(dx)]
other = full_invals[len(dx):]
_, self._vjp_fun = jax.vjp(lambda *args: self.compute_primal(*args, *other), *x)
self._vjp_hash = inhash
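                # cache the primal output shape, whether compute_primal returns a
                # tuple, and the number of continuous inputs so the vjp results can
                # be split between d_inputs and d_outputs.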
if self._compute_primals_out_shape is None:
shape = jax.eval_shape(lambda *args: self.compute_primal(*args, *other), *x)
if isinstance(shape, tuple):
shape = (tuple(s.shape for s in shape), True,
len(self._var_rel_names['input']))
else:
shape = (shape.shape, False, len(self._var_rel_names['input']))
self._compute_primals_out_shape = shape
shape, istup, ninputs = self._compute_primals_out_shape
if istup:
                deriv_vals = self._vjp_fun(tuple(d_residuals.values()))
else:
deriv_vals = self._vjp_fun(tuple(d_residuals.values())[0])
d_inputs.set_vals(deriv_vals[:ninputs])
d_outputs.set_vals(deriv_vals[ninputs:])

    def _get_compute_shape_func(self, name):
        """
        Return a function that computes the shape of the named output.
        """
return partial(self._compute_output_shape, name)

    def _compute_output_shape(self, name, input_shapes):
        """
        Compute the shape of the named output based on the input shapes.
        """
if self._output_shapes is None:
out_shapes = _compute_output_shapes(self._orig_compute_primal.__func__,
input_shapes)
self._output_shapes = {n: shp for n, shp in zip(self._var_rel_names['output'],
out_shapes)}
return self._output_shapes[name]