# Source code for openmdao.components.func_comp_common

"""
Define functions and objects common to the ExplicitFuncComp and ImplicitFuncComp classes.
"""

import sys
import traceback
import re
from functools import partial
from collections.abc import Iterable

import numpy as np
try:
    import jax
    from jax import vmap
    import jax.numpy as jnp
    # linear_util moved to jax.extend in jax 0.4.17, previous location is deprecated
    try:
        from jax.extend import linear_util
    except ImportError:
        from jax import linear_util
    from jax.api_util import argnums_partial
    from jax._src.api import _jvp, _vjp
    jax.config.update("jax_enable_x64", True)  # jax by default uses 32 bit floats
except Exception:
    _, err, tb = sys.exc_info()
    if not isinstance(err, ImportError):
        traceback.print_tb(tb)
    jax = None


# Pattern for a valid variable name: an ASCII letter or underscore followed by any number of
# ASCII letters, digits or underscores.  re.match only anchors at the start of the string, so a
# whole-name check must also confirm that the match covers the entire name.
namecheck_rgx = re.compile(r'[_a-zA-Z][_a-zA-Z0-9]*')

# Reserved names: these are keywords used for variable options, so they may not be used as
# input or output variable names.
_disallowed_varnames = {
    'units',
    'shape',
    'shape_by_conn',
    'run_root_only',
    'distributed',
    'assembled_jac_type',
}


def _copy_with_ignore(dct, keep, ignore=()):
    """
    Copy the entries in the given dict whose keys are in keep and not in ignore.

    Entries whose keys are not in keep are silently dropped; no warning is issued.

    Parameters
    ----------
    dct : dict
        The dictionary to be copied.
    keep : set-like
        Set of keys for entries we want to keep.
    ignore : set or tuple
        Keys to leave out of the copy, even if they also appear in keep.

    Returns
    -------
    dict
        A new dict containing the entries of dct whose keys are in keep but not in ignore.
    """
    return {k: v for k, v in dct.items() if k in keep and k not in ignore}


def _check_var_name(comp, name):
    """
    Raise NameError if name cannot be used as a variable name.

    A usable name consists only of ASCII letters, digits and underscores, does not start with
    a digit, and is not one of the reserved option keywords.

    Parameters
    ----------
    comp : object
        The component the variable belongs to; only its msginfo attribute is used, to prefix
        the error messages.
    name : str
        The candidate variable name.

    Raises
    ------
    NameError
        If name is not a valid identifier or is a reserved keyword.
    """
    # fullmatch succeeds exactly when the greedy prefix match covers the entire name
    if namecheck_rgx.fullmatch(name) is None:
        raise NameError(f"{comp.msginfo}: '{name}' is not a valid variable name.")

    if name not in _disallowed_varnames:
        return

    raise NameError(f"{comp.msginfo}: cannot use variable name '{name}' because "
                    "it's a reserved keyword.")


def jac_forward(fun, argnums, tangents):
    """
    Similar to the jax.jacfwd function but allows specification of the tangent matrix.

    This allows us to generate a compressed jacobian based on coloring.

    Parameters
    ----------
    fun : function
        The function to be differentiated.
    argnums : tuple of int or None
        Specifies which positional args are dynamic.  None means all positional args are dynamic.
    tangents : ndarray or tuple of ndarray
        Array of 1.0's and 0's that is used to compute the value of the jacobian matrix.
        Its leading axis is the axis that vmap maps over.

    Returns
    -------
    function
        If there are multiple output variables, returns a function that returns rows of the
        jacobian grouped by output variable, e.g., if there were 2 output variables of size 3
        and 4, the function would return a list with two entries. The first entry would contain
        the first 3 rows of J and the second would contain the next 4 rows of J.
        If there is only 1 output variable, the values returned are grouped by input variable.
    """
    f = linear_util.wrap_init(fun)

    # NOTE(review): _jvp requires its tangents to have the same pytree structure as the primals
    # (a tuple here); confirm callers never pass a bare array for tangents.
    if argnums is None:
        def jacfunf(*args):
            # vmap over the leading axis of tangents.  out_axes=(None, -1): the primal output is
            # identical for every tangent so it is not batched, and the batch axis of the tangent
            # output is moved last.  [1] selects the tangent output of _jvp.
            return vmap(partial(_jvp, f, args), out_axes=(None, -1))(tangents)[1]
    else:
        def jacfunf(*args):
            # freeze the non-dynamic args so that only the argnums args are differentiated
            f_partial, dyn_args = argnums_partial(f, argnums, args)
            return vmap(partial(_jvp, f_partial, dyn_args), out_axes=(None, -1))(tangents)[1]

    return jacfunf


def jac_reverse(fun, argnums, tangents):
    """
    Similar to the jax.jacrev function but allows specification of the tangent matrix.

    This allows us to generate a compressed jacobian based on coloring.

    Parameters
    ----------
    fun : function
        The function to be differentiated.
    argnums : tuple of int or None
        Specifies which positional args are dynamic.  None means all positional args are dynamic.
    tangents : ndarray or tuple of ndarray
        Array of 1.0's and 0's that is used to compute the value of the jacobian matrix.
        Its leading axis is the axis that vmap maps over.

    Returns
    -------
    function
        A function that returns rows of the jacobian grouped by function input variable,
        e.g., if there were 3 input variables of size 5 and 7 and 9, the function would return
        a list with 3 entries. The first entry would contain the first 5 columns of J, the
        second the next 7 columns of J, and the third the next 9 columns of J.
        Note that for implicit systems, the function inputs will contain both inputs and
        outputs in the context of OpenMDAO.
    """
    f = linear_util.wrap_init(fun)

    if argnums is None:
        def jacfunr(*args):
            # _vjp returns (outputs, vjp_fn); vjp_fn is vmapped over the leading axis of
            # tangents, giving one cotangent result per function arg.
            return vmap(_vjp(f, *args)[1])(tangents)
    else:
        def jacfunr(*args):
            # freeze the non-dynamic args so that only the argnums args are differentiated
            f_partial, dyn_args = argnums_partial(f, argnums, args)
            return vmap(_vjp(f_partial, *dyn_args)[1])(tangents)

    return jacfunr


def jacvec_prod(fun, argnums, invals, tangent):
    """
    Similar to the jvp function but gives back a flat column.

    Note: this is significantly slower (when producing a full jacobian) than jac_forward.

    Parameters
    ----------
    fun : function
        The function to be differentiated.
    argnums : tuple of int or None
        Specifies which positional args are dynamic.  None means all positional args are dynamic.
    invals : tuple of float or ndarray
        Function input values.  When argnums is not None, only the entries selected by argnums
        are used, to determine the shapes of the tangents.
    tangent : ndarray
        Flat array of 1.0's and 0's that is used to compute a column of the jacobian matrix.

    Returns
    -------
    function
        A function to compute the jacobian vector product.
    """
    f = linear_util.wrap_init(fun)
    if argnums is not None:
        invals = list(argnums_partial(f, argnums, invals)[1])

    # Split the flat tangent into one piece per dynamic input and give each piece that input's
    # shape.  np.shape (rather than the .shape attribute) also works for plain float inputs.
    sizes = np.array([jnp.size(a) for a in invals])
    inds = np.cumsum(sizes[:-1])
    shaped_tangents = [a.reshape(np.shape(s))
                       for a, s in zip(np.split(tangent, inds, axis=0), invals)]

    if argnums is None:
        def jvfun(inps):
            # _jvp requires primals and tangents to share the same pytree structure, so convert
            # inps to a list to match shaped_tangents in case it is passed as a tuple.
            return _jvp(f, list(inps), shaped_tangents)[1]
    else:
        def jvfun(inps):
            f_partial, dyn_args = argnums_partial(f, argnums, inps)
            return _jvp(f_partial, list(dyn_args), shaped_tangents)[1]

    return jvfun


def _get_tangents(vals, direction, coloring=None, argnums=None, trans=None):
    """
    Return a tuple of tangents values for use with vmap.

    Parameters
    ----------
    vals : list
        List of function input values.
    direction : str
        Derivative computation direction ('fwd' or 'rev').
    coloring : Coloring or None
        If not None, the Coloring object used to compute a compressed tangent array.
    argnums : list of int or None
        Indices of dynamic (differentiable) function args.
    trans : ndarray
        Translation array from jacobian indices into function arg indices.  This is needed
        because OpenMDAO expects ordering to be outputs first, then inputs, but function args
        could be in any order.

    Returns
    -------
    tuple of ndarray or ndarray
        The tangents values to be passed to vmap.  There is one array per dynamic arg, shaped
        (ntangents,) + arg shape.  A bare array (not a 1-tuple) is returned when there is
        exactly one dynamic arg, and an empty tuple when there are none.
    """
    if argnums is None:
        leaves = vals
    else:
        leaves = [vals[i] for i in argnums]

    sizes = [np.size(a) for a in leaves]
    # column offsets at which the full tangent matrix is split into per-arg blocks
    inds = np.cumsum(sizes[:-1])

    if coloring is None:
        # builtin sum gives int 0 for an empty list; np.sum([]) is the float 0.0, which
        # np.eye rejects with a TypeError
        tangent = np.eye(sum(sizes))
        if trans is not None:
            tangent = tangent[:, trans]
    else:
        tangent = coloring.tangent_matrix(direction, trans=trans)

    # each block keeps the leading (vmap) axis of the tangent matrix and takes its arg's shape
    shapes = [tangent.shape[:1] + np.shape(v) for v in leaves]
    tangents = tuple(np.reshape(a, shp)
                     for a, shp in zip(np.split(tangent, inds, axis=1), shapes))

    if len(leaves) == 1:
        return tangents[0]

    return tangents


def _ensure_iter(val):
    """
    Return the given value unchanged if it is iterable, otherwise wrap it in a 1-tuple.

    Parameters
    ----------
    val : object
        The value to be iterated over.

    Returns
    -------
    tuple or iterable
        val itself if it is an instance of collections.abc.Iterable (this includes strings
        and numpy arrays), otherwise the tuple (val,).
    """
    if isinstance(val, Iterable):
        return val
    return (val,)