"""
Define functions and objects common to the ExplicitFuncComp and ImplicitFuncComp classes.
"""
import sys
import traceback
import re
from functools import partial
import numpy as np
try:
    import jax
    from jax import vmap
    import jax.numpy as jnp
    # linear_util moved to jax.extend in jax 0.4.17, previous location is deprecated
    try:
        from jax.extend import linear_util
    except ImportError:
        from jax import linear_util
    from jax.api_util import argnums_partial
    from jax._src.api import _jvp, _vjp, api_util
    jax.config.update("jax_enable_x64", True)  # jax by default uses 32 bit floats
except Exception:
    _, err, tb = sys.exc_info()
    if not isinstance(err, ImportError):
        traceback.print_tb(tb)
    jax = None
# regex to check for variable names.
namecheck_rgx = re.compile('[_a-zA-Z][_a-zA-Z0-9]*')
# Names that are not allowed for input or output variables (keywords for options)
_disallowed_varnames = {
    'units', 'shape', 'shape_by_conn', 'run_root_only', 'distributed', 'assembled_jac_type'
}
def _copy_with_ignore(dct, keep, ignore=()):
    """
    Copy the entries in the given dict whose keys are in keep.
    Parameters
    ----------
    dct : dict
        The dictionary to be copied.
    keep : set-like
        Set of keys for entries we want to keep.
    ignore : set or tuple
        Don't issue a warning for these non-keeper keys.
    Returns
    -------
    dict
        A new dict containing 'keep' entries.
    """
    return {k: v for k, v in dct.items() if k in keep and k not in ignore}
def _check_var_name(comp, name):
    match = namecheck_rgx.match(name)
    if match is None or match.group() != name:
        raise NameError(f"{comp.msginfo}: '{name}' is not a valid variable name.")
    if name in _disallowed_varnames:
        raise NameError(f"{comp.msginfo}: cannot use variable name '{name}' because "
                        "it's a reserved keyword.")
[docs]
def jac_forward(fun, argnums, tangents):
    """
    Similar to the jax.jacfwd function but allows specification of the tangent matrix.
    This allows us to generate a compressed jacobian based on coloring.
    Parameters
    ----------
    fun : function
        The function to be differentiated.
    argnums : tuple of int or None
        Specifies which positional args are dynamic.  None means all positional args are dynamic.
    tangents : ndarray
        Array of 1.0's and 0's that is used to compute the value of the jacobian matrix.
    Returns
    -------
    function
        If there are multiple output variables, returns a function that returns rows of the
        jacobian grouped by output variable, e.g., if there were 2 output variables of size 3 and 4,
        the function would return a list with two entries. The first entry would contain the first 3
        rows of J and the second would contain the next 4 rows of J.  If there is only 1 output
        variable, the values returned are grouped by input variable.
    """
    try:
        # Newer JAX versions (>= 0.4.x)
        f = linear_util.wrap_init(fun, debug_info=api_util.debug_info('jac_forward', fun, (), {}))
    except TypeError:
        # Older JAX versions
        f = linear_util.wrap_init(fun)
    if argnums is None:
        def jacfunf(*args):
            return vmap(partial(_jvp, f, args), out_axes=(None, -1))(tangents)[1]
    else:
        def jacfunf(*args):
            f_partial, dyn_args = argnums_partial(f, argnums, args)
            return vmap(partial(_jvp, f_partial, dyn_args), out_axes=(None, -1))(tangents)[1]
    return jacfunf 
[docs]
def jac_reverse(fun, argnums, tangents):
    """
    Similar to the jax.jacrev function but allows specification of the tangent matrix.
    This allows us to generate a compressed jacobian based on coloring.
    Parameters
    ----------
    fun : function
        The function to be differentiated.
    argnums : tuple of int or None
        Specifies which positional args are dynamic.  None means all positional args are dynamic.
    tangents : ndarray
        Array of 1.0's and 0's that is used to compute the value of the jacobian matrix.
    Returns
    -------
    function
        A function that returns rows of the jacobian grouped by function input variable, e.g., if
        there were 3 input variables of size 5 and 7 and 9, the function
        would return a list with 3 entries. The first entry would contain the first 5 columns of J,
        the second the next 7 columns of J, and the third the next 9 columns of J.  Note that for
        implicit systems, the function inputs will contain both inputs and outputs in the context
        of OpenMDAO.
    """
    try:
        # Newer JAX versions (>= 0.4.x)
        f = linear_util.wrap_init(fun, debug_info=api_util.debug_info('jac_reverse', fun, (), {}))
    except TypeError:
        # Older JAX versions
        f = linear_util.wrap_init(fun)
    if argnums is None:
        def jacfunr(*args):
            return vmap(_vjp(f, *args)[1])(tangents)
    else:
        def jacfunr(*args):
            f_partial, dyn_args = argnums_partial(f, argnums, args)
            return vmap(_vjp(f_partial, *dyn_args)[1])(tangents)
    return jacfunr 
[docs]
def jacvec_prod(fun, argnums, invals, tangent):
    """
    Similar to the jvp function but gives back a flat column.
    Note: this is significantly slower (when producing a full jacobian) than jac_forward.
    Parameters
    ----------
    fun : function
        The function to be differentiated.
    argnums : tuple of int or None
        Specifies which positional args are dynamic.  None means all positional args are dynamic.
    invals : tuple of float or ndarray
        Dynamic function input values.
    tangent : ndarray
        Array of 1.0's and 0's that is used to compute a column of the jacobian matrix.
    Returns
    -------
    function
        A function to compute the jacobian vector product.
    """
    try:
        # Newer JAX versions (>= 0.4.x)
        f = linear_util.wrap_init(fun, debug_info=api_util.debug_info('jacvec_prod', fun, (), {}))
    except TypeError:
        # Older JAX versions
        f = linear_util.wrap_init(fun)
    if argnums is not None:
        invals = list(argnums_partial(f, argnums, invals)[1])
    # compute shaped tangents to use later
    sizes = np.array([jnp.size(a) for a in invals])
    inds = np.cumsum(sizes[:-1])
    shaped_tangents = [a.reshape(s.shape) for a, s in zip(np.split(tangent, inds, axis=0), invals)]
    if argnums is None:
        def jvfun(inps):
            return _jvp(f, inps, shaped_tangents)[1]
    else:
        def jvfun(inps):
            f_partial, dyn_args = argnums_partial(f, argnums, inps)
            return _jvp(f_partial, list(dyn_args), shaped_tangents)[1]
    return jvfun 
def _get_tangents(vals, direction, coloring=None, argnums=None, trans=None):
    """
    Return a tuple of tangents values for use with vmap.
    Parameters
    ----------
    vals : list
        List of function input values.
    direction : str
        Derivative computation direction ('fwd' or 'rev').
    coloring : Coloring or None
        If not None, the Coloring object used to compute a compressed tangent array.
    argnums : list of int or None
        Indices of dynamic (differentiable) function args.
    trans : ndarray
        Translation array from jacobian indices into function arg indices.  This is needed
        because OpenMDAO expects ordering to be outputs first, then inputs, but function args
        could be in any order.
    Returns
    -------
    tuple of ndarray or ndarray
        The tangents values to be passed to vmap.
    """
    if argnums is None:
        leaves = vals
    else:
        leaves = [vals[i] for i in argnums]
    sizes = [np.size(a) for a in leaves]
    inds = np.cumsum(sizes[:-1])
    if coloring is None:
        tangent = np.eye(np.sum(sizes))
        if trans is not None:
            tangent = tangent[:, trans]
    else:
        tangent = coloring.tangent_matrix(direction, trans=trans)
    shapes = [tangent.shape[:1] + np.shape(v) for v in leaves]
    tangents = tuple([np.reshape(a, shp) for a, shp in zip(np.split(tangent, inds, axis=1),
                                                           shapes)])
    if len(leaves) == 1:
        tangents = tangents[0]
    return tangents
def _ensure_iter(val):
    """
    Turn the given value into an iterator if it is not already.
    Parameters
    ----------
    val : object
        The value to be iterated over.
    Returns
    -------
    tuple or iterable
    """
    if isinstance(val, tuple):
        return val
    return val,