Source code for openmdao.components.implicit_func_comp

"""Define the ImplicitFuncComp class."""

import sys
import traceback
from itertools import chain
import numpy as np
from openmdao.core.implicitcomponent import ImplicitComponent
from openmdao.core.constants import INT_DTYPE
import openmdao.func_api as omf
from openmdao.components.func_comp_common import _check_var_name, _copy_with_ignore, \
    jac_forward, jac_reverse, _get_tangents, _ensure_iter
from openmdao.utils.array_utils import shape_to_len

try:
    import jax
    from jax import jit
    jax.config.update("jax_enable_x64", True)  # jax by default uses 32 bit floats
except Exception:
    _, err, tb = sys.exc_info()
    if not isinstance(err, ImportError):
        traceback.print_tb(tb)
    jax = None


[docs]class ImplicitFuncComp(ImplicitComponent): """ An implicit component that wraps a python function. Parameters ---------- apply_nonlinear : function The function to be wrapped by this Component. solve_nonlinear : function or None Optional function to perform a nonlinear solve. linearize : function or None Optional function to compute partial derivatives. solve_linear : function or None Optional function to perform a linear solve. **kwargs : named args Args passed down to ImplicitComponent. Attributes ---------- _apply_nonlinear_func : callable The function wrapper used by this component. _apply_nonlinear_func_jax : callable Function decorated to ensure use of jax numpy. _solve_nonlinear_func : function or None Optional function to do a nonlinear solve. solve_nonlinear : method Local override of _solve_nonlinear method. _solve_linear_func : function or None Optional function to do a linear solve. solve_linear : method Local override of solve_linear method. _linearize_func : function or None Optional function to compute partial derivatives. linearize : method Local override of linearize method. _linearize_info : object Some state information to compute in _linearize_func and pass to _solve_linear_func _tangents : tuple Tuple of parts of the tangent matrix cached for jax derivative computation. _tangent_direction : str Direction of the last tangent computation. _jac2func_inds : ndarray Translation array from jacobian indices to function array indices. """
[docs] def __init__(self, apply_nonlinear, solve_nonlinear=None, linearize=None, solve_linear=None, **kwargs): """ Initialize attributes. """ super().__init__(**kwargs) self._apply_nonlinear_func = omf.wrap(apply_nonlinear) self._solve_nonlinear_func = solve_nonlinear self._solve_linear_func = solve_linear self._linearize_func = linearize self._linearize_info = None self._tangents = None self._tangent_direction = None self._jac2func_inds = None if solve_nonlinear: self.solve_nonlinear = self._user_solve_nonlinear self._has_solve_nl = True if linearize: self.linearize = self._user_linearize if solve_linear: self.solve_linear = self._user_solve_linear if self._apply_nonlinear_func._use_jax: self.options['derivs_method'] = 'jax' # setup requires an undecorated, unjitted function, so do it now if self._apply_nonlinear_func._call_setup: self._apply_nonlinear_func._setup() if self.options['derivs_method'] == 'jax': if jax is None: raise RuntimeError(f"{self.msginfo}: jax is not installed. " "Try 'pip install openmdao[jax]' with Python>=3.8.") self._apply_nonlinear_func_jax = omf.jax_decorate(self._apply_nonlinear_func._f) if self.options['derivs_method'] == 'jax' and self.options['use_jit']: static_argnums = [i for i, m in enumerate(self._apply_nonlinear_func._inputs.values()) if 'is_option' in m] try: with omf.jax_context(self._apply_nonlinear_func._f.__globals__): self._apply_nonlinear_func_jax = jit(self._apply_nonlinear_func_jax, static_argnums=static_argnums) except Exception as err: raise RuntimeError(f"{self.msginfo}: failed jit compile of solve_nonlinear " f"function: {err}")
[docs] def setup(self): """ Define our inputs and outputs. """ optignore = {'is_option'} for name, meta in self._apply_nonlinear_func.get_input_meta(): _check_var_name(self, name) if 'is_option' in meta and meta['is_option']: kwargs = _copy_with_ignore(meta, omf._allowed_declare_options_args, ignore=optignore) self.options.declare(name, **kwargs) else: kwargs = omf._filter_dict(meta, omf._allowed_add_input_args) self.add_input(name, **kwargs) for i, (name, meta) in enumerate(self._apply_nonlinear_func.get_output_meta()): _check_var_name(self, name) kwargs = _copy_with_ignore(meta, omf._allowed_add_output_args, ignore=('resid',)) self.add_output(name, **kwargs)
def _setup_jax(self, from_group=False): # TODO: this is here to prevent the ImplicitComponent base class from trying to do its # own jax setup if derivs_method is 'jax'. We should probably refactor this... pass
[docs] def declare_partials(self, *args, **kwargs): """ Declare information about this component's subjacobians. Parameters ---------- *args : list Positional args to be passed to base class version of declare_partials. **kwargs : dict Keyword args to be passed to base class version of declare_partials. Returns ------- dict Metadata dict for the specified partial(s). """ if self._linearize_func is None and ('method' not in kwargs or kwargs['method'] == 'exact'): raise RuntimeError(f"{self.msginfo}: declare_partials must be called with method equal " "to 'cs', 'fd', or 'jax'.") return super().declare_partials(*args, **kwargs)
def _setup_partials(self): """ Check that all partials are declared. """ kwargs = self._apply_nonlinear_func.get_declare_coloring() if kwargs is not None: self.declare_coloring(**kwargs) for kwargs in self._apply_nonlinear_func.get_declare_partials(): self.declare_partials(**kwargs) super()._setup_partials()
[docs] def apply_nonlinear(self, inputs, outputs, residuals, discrete_inputs=None, discrete_outputs=None): """ R = Ax - b. Parameters ---------- inputs : Vector Unscaled, dimensional input variables read via inputs[key]. outputs : Vector Unscaled, dimensional output variables read via outputs[key]. residuals : Vector Unscaled, dimensional residuals written to via residuals[key]. discrete_inputs : _DictValues or None Dict-like object containing discrete inputs. discrete_outputs : _DictValues or None Dict-like object containing discrete outputs. """ residuals.set_vals(_ensure_iter( self._apply_nonlinear_func(*self._ordered_func_invals(inputs, outputs))))
def _user_solve_nonlinear(self, inputs, outputs): """ Compute outputs. The model is assumed to be in a scaled state. """ self._outputs.set_vals(_ensure_iter( self._solve_nonlinear_func(*self._ordered_func_invals(inputs, outputs)))) def _linearize(self, jac=None, sub_do_ln=False): """ Compute jacobian / factorization. The model is assumed to be in a scaled state. Parameters ---------- jac : Jacobian or None Ignored. sub_do_ln : bool Flag indicating if the children should call linearize on their linear solvers. """ if self.options['derivs_method'] == 'jax': if self._mode != self._tangent_direction: # force recomputation of coloring and tangents self._first_call_to_linearize = True self._tangents = None self._check_first_linearize() self._jax_linearize() if (jac is None or jac is self._assembled_jac) and self._assembled_jac is not None: self._assembled_jac._update(self) else: super()._linearize(jac, sub_do_ln) def _jax_linearize(self): """ Compute the jacobian using jax. This updates self._jacobian. """ func = self._apply_nonlinear_func # argnums specifies which position args are to be differentiated inames = list(func.get_input_names()) argnums = [i for i, m in enumerate(func._inputs.values()) if 'is_option' not in m] if len(argnums) == len(inames): argnums = None # speedup if there are no static args osize = len(self._outputs) isize = len(self._inputs) + osize invals = list(self._ordered_func_invals(self._inputs, self._outputs)) coloring = self._coloring_info['coloring'] if self._mode == 'rev': # use reverse mode to compute derivs outvals = tuple(self._outputs.values()) tangents = self._get_tangents(outvals, 'rev', coloring) if coloring is not None: j = [np.asarray(a).reshape((a.shape[0], shape_to_len(a.shape[1:]))) for a in jac_reverse(self._apply_nonlinear_func_jax, argnums, tangents)(*invals)] j = coloring.expand_jac(np.hstack(self._reorder_col_chunks(j)), 'rev') else: j = [] for a in jac_reverse(self._apply_nonlinear_func_jax, argnums, tangents)(*invals): a = np.asarray(a) if a.ndim < 2: a = a.reshape((a.size, 1)) else: a = a.reshape((a.shape[0], shape_to_len(a.shape[1:]))) j.append(a) j = np.hstack(self._reorder_col_chunks(j)).reshape((osize, isize)) else: if coloring is not None: tangents = self._get_tangents(invals, 'fwd', coloring, argnums, trans=self._get_jac2func_inds(self._inputs, self._outputs)) j = [np.asarray(a).reshape((shape_to_len(a.shape[:-1]), a.shape[-1])) for a in jac_forward(self._apply_nonlinear_func_jax, argnums, tangents)(*invals)] j = coloring.expand_jac(np.vstack(j), 'fwd') else: tangents = self._get_tangents(invals, 'fwd', coloring, argnums) j = [] for a in jac_forward(self._apply_nonlinear_func_jax, argnums, tangents)(*invals): a = np.asarray(a) if a.ndim < 2: a = a.reshape((1, a.size)) else: a = a.reshape((shape_to_len(a.shape[:-1]), a.shape[-1])) j.append(a) j = self._reorder_cols(np.vstack(j).reshape((osize, isize))) self._jacobian.set_dense_jac(self, j) def _user_linearize(self, inputs, outputs, jacobian): """ Calculate the partials of the residual for each balance. Parameters ---------- inputs : Vector Unscaled, dimensional input variables read via inputs[key]. outputs : Vector Unscaled, dimensional output variables read via outputs[key]. jacobian : Jacobian Sub-jac components written to jacobian[output_name, input_name]. """ self._linearize_info = self._linearize_func(*chain(self._ordered_func_invals(inputs, outputs), (jacobian,))) def _user_solve_linear(self, d_outputs, d_residuals, mode): r""" Run solve_linear function if there is one. Parameters ---------- d_outputs : Vector Unscaled, dimensional quantities read via d_outputs[key]. d_residuals : Vector Unscaled, dimensional quantities read via d_residuals[key]. mode : str Derivative solution direction, either 'fwd' or 'rev'. """ if mode == 'fwd': d_outputs.set_vals(_ensure_iter( self._solve_linear_func(*chain(d_residuals.values(), (mode, self._linearize_info))))) else: # rev d_residuals.set_vals(_ensure_iter( self._solve_linear_func(*chain(d_outputs.values(), (mode, self._linearize_info))))) def _ordered_func_invals(self, inputs, outputs): """ Yield function input args in their proper order. In OpenMDAO, states are outputs, but for our some of our functions they are inputs, so this function yields the values of the inputs and states in the same order that they were originally given for the _apply_nonlinear_func. Parameters ---------- inputs : Vector The input vector. outputs : Vector The output vector (contains the states). Yields ------ float or ndarray Value of input or state variable. """ inps = inputs.values() outs = outputs.values() for name, meta in self._apply_nonlinear_func._inputs.items(): if 'is_option' in meta: # it's an option yield self.options[name] elif 'resid' in meta: # it's a state yield next(outs) else: yield next(inps) def _get_jac2func_inds(self, inputs, outputs): """ Return a translation array from jac column indices into function input ordering. Parameters ---------- inputs : Vector The input vector. outputs : Vector The output vector (contains the states). Returns ------- ndarray Index translation array """ if self._jac2func_inds is None: inds = np.arange(len(outputs) + len(inputs), dtype=INT_DTYPE) indict = {} start = end = 0 for n, meta in self._apply_nonlinear_func._inputs.items(): if 'is_option' not in meta: end += shape_to_len(meta['shape']) indict[n] = inds[start:end] start = end inds = [indict[n] for n in chain(outputs, inputs)] self._jac2func_inds = np.concatenate(inds) return self._jac2func_inds def _reorder_col_chunks(self, col_chunks): """ Return jacobian column chunks in correct OpenMDAO order (outputs first, then inputs). This is needed in rev mode because the return values of the jacrev function are ordered based on the order of the function inputs, which may be different than OpenMDAO's required order. Parameters ---------- col_chunks : list of ndarray List of column chunks to be reordered Returns ------- list Chunks in OpenMDAO jacobian order. """ inps = [] ordered_chunks = [] chunk_iter = iter(col_chunks) for meta in self._apply_nonlinear_func._inputs.values(): if 'is_option' in meta: # it's an option pass # skip it (don't include in jacobian) elif 'resid' in meta: # it's a state ordered_chunks.append(next(chunk_iter)) else: inps.append(next(chunk_iter)) return ordered_chunks + inps def _reorder_cols(self, arr, coloring=None): """ Reorder the columns of jacobian row chunks in fwd mode. Parameters ---------- arr : ndarray Jacobian or compressed jacobian. coloring : Coloring or None Coloring object. Returns ------- ndarray Reordered array. """ if coloring is None: trans = self._get_jac2func_inds(self._inputs, self._outputs) return arr[:, trans] else: trans = self._get_jac2func_inds(self._inputs, self._outputs) J = np.zeros(coloring._shape) for col, nzpart, icol in coloring.colored_jac_iter(arr, 'fwd', trans): J[nzpart, icol] = col return J def _get_tangents(self, vals, direction, coloring=None, argnums=None, trans=None): """ Return a tuple of tangents values for use with vmap. Parameters ---------- vals : list List of function input values. direction : str Derivative computation direction ('fwd' or 'rev'). coloring : Coloring or None If not None, the Coloring object used to compute a compressed tangent array. argnums : list of int or None Indices of dynamic (differentiable) function args. trans : ndarray Translation array from jacobian indices into function arg indices. This is needed because OpenMDAO expects ordering to be outputs first, then inputs, but function args could be in any order. Returns ------- tuple of ndarray or ndarray The tangents values to be passed to vmap. """ if self._tangents is None: self._tangents = _get_tangents(vals, direction, coloring, argnums, trans) self._tangent_direction = direction return self._tangents def _compute_coloring(self, recurse=False, **overrides): """ Compute a coloring of the partial jacobian. This assumes that the current System is in a proper state for computing derivatives. It just calls the base class version and then resets the tangents so that after coloring a new set of compressed tangents values can be computed. Parameters ---------- recurse : bool If True, recurse from this system down the system hierarchy. Whenever a group is encountered that has specified its coloring metadata, we don't recurse below that group unless that group has a subsystem that has a nonlinear solver that uses gradients. **overrides : dict Any args that will override either default coloring settings or coloring settings resulting from an earlier call to declare_coloring. Returns ------- list of Coloring The computed colorings. """ ret = super()._compute_coloring(recurse, **overrides) self._tangents = None # reset to compute new colored tangents later return ret