# Source code for openmdao.core.explicitcomponent

"""Define the ExplicitComponent class."""
import inspect

import numpy as np
from types import MethodType


from openmdao.jacobians.dictionary_jacobian import DictionaryJacobian
from openmdao.core.component import Component
from openmdao.vectors.vector import _full_slice
from openmdao.utils.class_util import overrides_method
from openmdao.recorders.recording_iteration_stack import Recording
from openmdao.core.constants import INT_DTYPE, _UNDEFINED
from openmdao.utils.jax_utils import jax, jit, ExplicitCompJaxify, \
    compute_partials as _jax_compute_partials, \
    compute_jacvec_product as _jax_compute_jacvec_product, ReturnChecker, _jax_register_pytree_class


_tuplist = (tuple, list)


class ExplicitComponent(Component):
    """
    Class to inherit from when all output variables are explicit.

    Parameters
    ----------
    **kwargs : dict of keyword arguments
        Keyword arguments that will be mapped into the Component options.

    Attributes
    ----------
    _has_compute_partials : bool
        If True, the instance overrides compute_partials.
    _vjp_hash : int or None
        Hash value for the last set of inputs to the compute_primal function.
    _vjp_fun : function or None
        The vector-Jacobian product function.
    """

    def __init__(self, **kwargs):
        """
        Initialize the component and record whether compute_partials is overridden.

        Also undeclares the 'assembled_jac_type' option and clears the cached
        vector-Jacobian product state.
        """
        super().__init__(**kwargs)

        self._has_compute_partials = overrides_method('compute_partials', self, ExplicitComponent)
        self.options.undeclare('assembled_jac_type')
        self._vjp_hash = None
        self._vjp_fun = None

    @property
    def nonlinear_solver(self):
        """
        Get the nonlinear solver for this system.
        """
        return self._nonlinear_solver

    @nonlinear_solver.setter
    def nonlinear_solver(self, solver):
        """
        Raise an exception.
        """
        raise RuntimeError(f"{self.msginfo}: Explicit components don't support nonlinear solvers.")

    @property
    def linear_solver(self):
        """
        Get the linear solver for this system.
        """
        return self._linear_solver

    @linear_solver.setter
    def linear_solver(self, solver):
        """
        Raise an exception.
        """
        raise RuntimeError(f"{self.msginfo}: Explicit components don't support linear solvers.")

    def _configure(self):
        """
        Configure this system to assign children settings and detect if matrix_free.
        """
        # Only auto-detect when the user has not set matrix_free explicitly.
        if self.matrix_free is _UNDEFINED:
            self.matrix_free = overrides_method('compute_jacvec_product', self, ExplicitComponent)

    def _jac_wrt_iter(self, wrt_matches=None):
        """
        Iterate over (name, start, end, vec, slice, dist_sizes) for each column var in the jacobian.

        Parameters
        ----------
        wrt_matches : set or None
            Only include column ('wrt') vars that are contained in this set.  This will determine
            what the actual offsets are, i.e. the offsets will be into a reduced jacobian
            containing only the matching columns.

        Yields
        ------
        str
            Absolute name of 'wrt' variable.
        int
            Starting index.
        int
            Ending index.
        Vector
            The _inputs vector.
        slice
            A full slice.
        ndarray or None
            Distributed sizes if var is distributed else None
        """
        start = end = 0
        local_ins = self._var_abs2meta['input']
        toidx = self._var_allprocs_abs2idx
        sizes = self._var_sizes['input']

        for wrt, meta in local_ins.items():
            if wrt_matches is None or wrt in wrt_matches:
                end += meta['size']
                # NOTE(review): we iterate local_ins itself, so this membership test is always
                # True and vec is always self._inputs here.
                vec = self._inputs if wrt in local_ins else None
                dist_sizes = sizes[:, toidx[wrt]] if meta['distributed'] else None
                yield wrt, start, end, vec, _full_slice, dist_sizes
                start = end

    def _setup_residuals(self):
        """
        Prevent the user from implementing setup_residuals for explicit components.
        """
        if overrides_method('setup_residuals', self, ExplicitComponent):
            raise RuntimeError(f'{self.msginfo}: Class overrides setup_residuals but '
                               'is an ExplicitComponent. setup_residuals may only be '
                               'overridden by ImplicitComponents.')

    def _setup_partials(self):
        """
        Call setup_partials in components.
        """
        super()._setup_partials()

        if self.matrix_free:
            return

        # Note: These declare calls are outside of setup_partials so that users do not have to
        # call the super version of setup_partials. This is still in the final setup.
        for out_abs, meta in self._var_abs2meta['output'].items():

            size = meta['size']
            if size > 0:
                # ExplicitComponent jacobians have -1 on the diagonal.
                arange = np.arange(size, dtype=INT_DTYPE)

                self._subjacs_info[out_abs, out_abs] = {
                    'rows': arange,
                    'cols': arange,
                    'shape': (size, size),
                    'val': np.full(size, -1.),
                    'dependent': True,
                }

    def _setup_jacobians(self, recurse=True):
        """
        Set and populate jacobian.

        Parameters
        ----------
        recurse : bool
            If True, setup jacobians in all descendants. (ignored)
        """
        if self._has_approx and self._use_derivatives:
            self._set_approx_partials_meta()

    def add_output(self, name, val=1.0, shape=None, units=None, res_units=None, desc='',
                   lower=None, upper=None, ref=1.0, ref0=0.0, res_ref=None, tags=None,
                   shape_by_conn=False, copy_shape=None, compute_shape=None, distributed=None):
        """
        Add an output variable to the component.

        For ExplicitComponent, res_ref defaults to the value of ref unless otherwise specified.

        Parameters
        ----------
        name : str
            Name of the variable in this component's namespace.
        val : float or list or tuple or ndarray
            The initial value of the variable being added in user-defined units. Default is 1.0.
        shape : int or tuple or list or None
            Shape of this variable, only required if val is not an array.
            Default is None.
        units : str or None
            Units in which the output variables will be provided to the component during execution.
            Default is None, which means it has no units.
        res_units : str or None
            Units in which the residuals of this output will be given to the user when requested.
            Default is None, which means it has no units.
        desc : str
            Description of the variable.
        lower : float or list or tuple or ndarray or None
            Lower bound(s) in user-defined units. It can be (1) a float, (2) an array_like
            consistent with the shape arg (if given), or (3) an array_like matching the shape of
            val, if val is array_like. A value of None means this output has no lower bound.
            Default is None.
        upper : float or list or tuple or ndarray or None
            Upper bound(s) in user-defined units. It can be (1) a float, (2) an array_like
            consistent with the shape arg (if given), or (3) an array_like matching the shape of
            val, if val is array_like. A value of None means this output has no upper bound.
            Default is None.
        ref : float
            Scaling parameter. The value in the user-defined units of this output variable when
            the scaled value is 1. Default is 1.
        ref0 : float
            Scaling parameter. The value in the user-defined units of this output variable when
            the scaled value is 0. Default is 0.
        res_ref : float or None
            Scaling parameter. The value in the user-defined res_units of this output's residual
            when the scaled value is 1. Default is None, which means residual scaling matches
            output scaling.
        tags : str or list of strs
            User defined tags that can be used to filter what gets listed when calling
            list_inputs and list_outputs and also when listing results from case recorders.
        shape_by_conn : bool
            If True, shape this output to match its connected input(s).
        copy_shape : str or None
            If a str, that str is the name of a variable. Shape this output to match that of
            the named variable.
        compute_shape : function or None
            If a function, that function is called to determine the shape of this output.
        distributed : bool or None
            If True, this variable is a distributed variable, so it can have different sizes/values
            across MPI processes.

        Returns
        -------
        dict
            Metadata for added variable.
        """
        # Residual scaling follows output scaling unless the caller says otherwise.
        if res_ref is None:
            res_ref = ref

        return super().add_output(name, val=val, shape=shape, units=units,
                                  res_units=res_units, desc=desc,
                                  lower=lower, upper=upper,
                                  ref=ref, ref0=ref0, res_ref=res_ref,
                                  tags=tags, shape_by_conn=shape_by_conn,
                                  copy_shape=copy_shape, compute_shape=compute_shape,
                                  distributed=distributed)

    def _approx_subjac_keys_iter(self):
        """
        Yield the subjac keys that are handled by one of this component's approximation schemes.

        A key is yielded when its metadata has a non-None 'method' that is a key of
        self._approx_schemes and its 'wrt' name (second element) is not found in the
        outputs vector.

        Yields
        ------
        tuple of (str, str)
            The (of, wrt) absolute-name key into self._subjacs_info.
        """
        is_output = self._outputs._contains_abs
        for abs_key, meta in self._subjacs_info.items():
            if 'method' in meta and not is_output(abs_key[1]):
                method = meta['method']
                if (method is not None and method in self._approx_schemes):
                    yield abs_key

    def _compute_wrapper(self):
        """
        Call compute based on the value of the "run_root_only" option.
        """
        with self._call_user_function('compute'):
            if self._run_root_only():
                if self.comm.rank == 0:
                    if self._discrete_inputs or self._discrete_outputs:
                        self.compute(self._inputs, self._outputs,
                                     self._discrete_inputs, self._discrete_outputs)
                    else:
                        self.compute(self._inputs, self._outputs)
                    # Share the root's continuous and discrete outputs with the other ranks.
                    self.comm.bcast([self._outputs.asarray(), self._discrete_outputs], root=0)
                else:
                    new_outs, new_disc_outs = self.comm.bcast(None, root=0)
                    self._outputs.set_val(new_outs)
                    if new_disc_outs:
                        for name, val in new_disc_outs.items():
                            self._discrete_outputs[name] = val
            else:
                if self._discrete_inputs or self._discrete_outputs:
                    self.compute(self._inputs, self._outputs,
                                 self._discrete_inputs, self._discrete_outputs)
                else:
                    self.compute(self._inputs, self._outputs)

    def _apply_nonlinear(self):
        """
        Compute residuals. The model is assumed to be in a scaled state.
        """
        outputs = self._outputs
        residuals = self._residuals
        with self._unscaled_context(outputs=[outputs], residuals=[residuals]):
            residuals.set_vec(outputs)

            # Sign of the residual is minus the sign of the output vector.
            residuals *= -1.0
            self._compute_wrapper()
            # Net effect: residuals = (newly computed outputs) - (original outputs), and the
            # outputs vector is restored to its original values.
            residuals += outputs
            outputs -= residuals

        self.iter_count_apply += 1

    def _solve_nonlinear(self):
        """
        Compute outputs. The model is assumed to be in a scaled state.
        """
        with Recording(self.pathname + '._solve_nonlinear', self.iter_count, self):
            with self._unscaled_context(outputs=[self._outputs], residuals=[self._residuals]):
                self._residuals.set_val(0.0)
                self._compute_wrapper()

            # Iteration counter is incremented in the Recording context manager at exit.

    def _compute_jacvec_product_wrapper(self, inputs, d_inputs, d_resids, mode,
                                        discrete_inputs=None):
        """
        Call compute_jacvec_product based on the value of the "run_root_only" option.

        Parameters
        ----------
        inputs : Vector
            Nonlinear input vector.
        d_inputs : Vector
            Linear input vector.
        d_resids : Vector
            Linear residual vector.
        mode : str
            Indicates direction of derivative computation, either 'fwd' or 'rev'.
        discrete_inputs : dict or None
            Mapping of variable name to discrete value.
        """
        if self._run_root_only():
            if self.comm.rank == 0:
                if discrete_inputs:
                    self.compute_jacvec_product(inputs, d_inputs, d_resids, mode, discrete_inputs)
                else:
                    self.compute_jacvec_product(inputs, d_inputs, d_resids, mode)
                # Broadcast whichever vector the product was written into for this mode.
                if mode == 'fwd':
                    self.comm.bcast(d_resids.asarray(), root=0)
                else:  # rev
                    self.comm.bcast(d_inputs.asarray(), root=0)
            else:
                new_vals = self.comm.bcast(None, root=0)
                if mode == 'fwd':
                    d_resids.set_val(new_vals)
                else:  # rev
                    d_inputs.set_val(new_vals)
        else:
            dochk = mode == 'rev' and self._problem_meta['checking'] and self.comm.size > 1

            if dochk:
                nzdresids = self._get_dist_nz_dresids()

            if discrete_inputs:
                self.compute_jacvec_product(inputs, d_inputs, d_resids, mode, discrete_inputs)
            else:
                self.compute_jacvec_product(inputs, d_inputs, d_resids, mode)

            if dochk:
                self._check_consistent_serial_dinputs(nzdresids)

    def _apply_linear(self, jac, mode, scope_out=None, scope_in=None):
        """
        Compute jac-vec product. The model is assumed to be in a scaled state.

        Parameters
        ----------
        jac : Jacobian or None
            If None, use local jacobian, else use jac.
        mode : str
            'fwd' or 'rev'.
        scope_out : set or None
            Set of absolute output names in the scope of this mat-vec product.
            If None, all are in the scope.
        scope_in : set or None
            Set of absolute input names in the scope of this mat-vec product.
            If None, all are in the scope.
        """
        J = self._jacobian if jac is None else jac

        with self._matvec_context(scope_out, scope_in, mode) as vecs:
            d_inputs, d_outputs, d_residuals = vecs

            if not self.matrix_free:
                # if we're not matrix free, we can skip the rest because
                # compute_jacvec_product does nothing.

                # Jacobian and vectors are all scaled, unitless
                J._apply(self, d_inputs, d_outputs, d_residuals, mode)
                return

            # Jacobian and vectors are all unscaled, dimensional
            with self._unscaled_context(outputs=[self._outputs], residuals=[d_residuals]):

                # set appropriate vectors to read_only to help prevent user error
                if mode == 'fwd':
                    d_inputs.read_only = True
                else:  # rev
                    d_residuals.read_only = True

                try:
                    # handle identity subjacs (output_or_resid wrt itself)
                    if J is None or isinstance(J, DictionaryJacobian):
                        if d_outputs._names:
                            rflat = d_residuals._abs_get_val
                            oflat = d_outputs._abs_get_val

                            subjacs_empty = len(self._subjacs_info) == 0

                            # 'val' in the code below is a reference to the part of the
                            # output or residual array corresponding to the variable 'v'
                            if mode == 'fwd':
                                for v in d_outputs._names:
                                    if subjacs_empty or (v, v) not in self._subjacs_info:
                                        val = rflat(v)
                                        val -= oflat(v)
                            else:  # rev
                                for v in d_outputs._names:
                                    if subjacs_empty or (v, v) not in self._subjacs_info:
                                        val = oflat(v)
                                        val -= rflat(v)

                    # We used to negate the residual here, and then re-negate after the hook
                    with self._call_user_function('compute_jacvec_product'):
                        self._compute_jacvec_product_wrapper(self._inputs, d_inputs, d_residuals,
                                                             mode, self._discrete_inputs)
                finally:
                    d_inputs.read_only = d_residuals.read_only = False

    def _solve_linear(self, mode, scope_out=_UNDEFINED, scope_in=_UNDEFINED):
        """
        Apply inverse jac product. The model is assumed to be in a scaled state.

        Parameters
        ----------
        mode : str
            'fwd' or 'rev'.
        scope_out : set, None, or _UNDEFINED
            Outputs relevant to possible lower level calls to _apply_linear on Components.
        scope_in : set, None, or _UNDEFINED
            Inputs relevant to possible lower level calls to _apply_linear on Components.
        """
        d_outputs = self._doutputs
        d_residuals = self._dresiduals

        if mode == 'fwd':
            if self._has_resid_scaling:
                with self._unscaled_context(outputs=[d_outputs], residuals=[d_residuals]):
                    d_outputs.set_vec(d_residuals)
            else:
                d_outputs.set_vec(d_residuals)

            # ExplicitComponent jacobian defined with -1 on diagonal.
            d_outputs *= -1.0

        else:  # rev
            if self._has_resid_scaling:
                with self._unscaled_context(outputs=[d_outputs], residuals=[d_residuals]):
                    d_residuals.set_vec(d_outputs)
            else:
                d_residuals.set_vec(d_outputs)

            # ExplicitComponent jacobian defined with -1 on diagonal.
            d_residuals *= -1.0

    def _compute_partials_wrapper(self):
        """
        Call compute_partials based on the value of the "run_root_only" option.
        """
        with self._call_user_function('compute_partials'):
            if self._run_root_only():
                if self.comm.rank == 0:
                    if self._discrete_inputs:
                        self.compute_partials(self._inputs, self._jacobian, self._discrete_inputs)
                    else:
                        self.compute_partials(self._inputs, self._jacobian)
                    # Share the root's subjacs with the other ranks.
                    self.comm.bcast(list(self._jacobian.items()), root=0)
                else:
                    for key, val in self.comm.bcast(None, root=0):
                        self._jacobian[key] = val
            else:
                if self._discrete_inputs:
                    self.compute_partials(self._inputs, self._jacobian, self._discrete_inputs)
                else:
                    self.compute_partials(self._inputs, self._jacobian)

    def _linearize(self, jac=None, sub_do_ln=False):
        """
        Compute jacobian / factorization. The model is assumed to be in a scaled state.

        Parameters
        ----------
        jac : Jacobian or None
            Ignored.
        sub_do_ln : bool
            Flag indicating if the children should call linearize on their linear solvers.
        """
        if self.matrix_free or not (self._has_compute_partials or self._approx_schemes):
            return

        self._check_first_linearize()

        with self._unscaled_context(outputs=[self._outputs], residuals=[self._residuals]):
            # Computing the approximation before the call to compute_partials allows users to
            # override FD'd values.
            for approximation in self._approx_schemes.values():
                approximation.compute_approximations(self, jac=self._jacobian)

            if self._has_compute_partials:
                # We used to negate the jacobian here, and then re-negate after the hook.
                self._compute_partials_wrapper()

    def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None):
        """
        Compute outputs given inputs. The model is assumed to be in an unscaled state.

        An inherited component may choose to either override this function or to define a
        compute_primal function.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        outputs : Vector
            Unscaled, dimensional output variables read via outputs[key].
        discrete_inputs : dict-like or None
            If not None, dict-like object containing discrete input values.
        discrete_outputs : dict-like or None
            If not None, dict-like object containing discrete output values.
        """
        # _tuplist is only read here, so no 'global' declaration is needed.
        if self.compute_primal is None:
            return

        returns = \
            self.compute_primal(*self._get_compute_primal_invals(inputs, discrete_inputs))

        # A single (non tuple/list) return value is treated as a 1-tuple.
        if not isinstance(returns, _tuplist):
            returns = (returns,)

        if not discrete_outputs:
            outputs.set_vals(returns)
        else:
            # Continuous outputs come first, followed by the discrete outputs.
            nvars = outputs.nvars()
            outputs.set_vals(returns[:nvars])
            self._discrete_outputs.set_vals(returns[nvars:])

    def compute_partials(self, inputs, partials, discrete_inputs=None):
        """
        Compute sub-jacobian parts. The model is assumed to be in an unscaled state.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        partials : Jacobian
            Sub-jac components written to partials[output_name, input_name].
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        """
        pass

    def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode, discrete_inputs=None):
        r"""
        Compute jac-vector product. The model is assumed to be in an unscaled state.

        If mode is:
            'fwd': d_inputs \|-> d_outputs

            'rev': d_outputs \|-> d_inputs

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        d_inputs : Vector
            See inputs; product must be computed only if var_name in d_inputs.
        d_outputs : Vector
            See outputs; product must be computed only if var_name in d_outputs.
        mode : str
            Either 'fwd' or 'rev'.
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.
        """
        pass

    def is_explicit(self):
        """
        Return True if this is an explicit component.

        Returns
        -------
        bool
            True if this is an explicit component.
        """
        return True

    def _get_compute_primal_invals(self, inputs, discrete_inputs):
        """
        Yield the inputs expected by the compute_primal method.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables Vector.
        discrete_inputs : dict or None
            If not None, dict containing discrete input values.

        Yields
        ------
        any
            Inputs expected by the compute_primal method.
        """
        yield from inputs.values()
        if discrete_inputs:
            yield from discrete_inputs.values()

    def _get_compute_primal_argnames(self):
        """
        Return the expected argnames for the compute_primal method.

        Returns
        -------
        list
            List of argnames expected by the compute_primal method.
        """
        argnames = []
        argnames.extend(self._var_rel_names['input'])
        if self._discrete_inputs:
            argnames.extend(self._discrete_inputs)
        return argnames

    def _setup_jax(self, from_group=False):
        """
        Set up the jax interface for this component.

        Parameters
        ----------
        from_group : bool
            If True, this is being called from a Group setup.
        """
        if self.matrix_free is True:
            self.compute_jacvec_product = MethodType(_jax_compute_jacvec_product, self)
        else:
            self.compute_partials = MethodType(_jax_compute_partials, self)
            self._has_compute_partials = True

        if self.compute_primal is None:
            # convert the compute method to a compute_primal method
            jaxifier = ExplicitCompJaxify(self, verbose=True)

            if jaxifier.get_self_statics:
                self.get_self_statics = MethodType(jaxifier.get_self_statics, self)

            # replace existing compute method with base class method, so that compute_primal
            # will be called.
            self.compute = MethodType(ExplicitComponent.compute, self)

            self.compute_primal = MethodType(jaxifier.compute_primal, self)
            self._compute_primal_returns_tuple = True
        else:
            # check that compute_primal args are in the correct order
            args = list(inspect.signature(self.compute_primal).parameters)
            if args and args[0] == 'self':
                args = args[1:]
            compargs = self._get_compute_primal_argnames()
            if args != compargs:
                raise RuntimeError(f"{self.msginfo}: compute_primal method args {args} don't match "
                                   f"the expected args {compargs}.")

            # determine if the compute_primal method returns a tuple
            self._compute_primal_returns_tuple = ReturnChecker(self.compute_primal).returns_tuple()

        if not from_group and self.options['use_jit']:
            # Discrete inputs follow self (index 0) and the continuous inputs, and are static.
            static_argnums = []
            idx = len(self._var_rel_names['input']) + 1
            static_argnums.extend(range(idx, idx + len(self._discrete_inputs)))
            self.compute_primal = MethodType(jit(self.compute_primal.__func__,
                                                 static_argnums=static_argnums), self)

        _jax_register_pytree_class(self.__class__)

    def _get_jac_func(self):
        """
        Return the jacobian function for this component.

        In forward mode, jax.jacfwd is used, and in reverse mode, jax.jacrev is used.  The
        direction is chosen automatically based on the sizes of the inputs and outputs.

        Returns
        -------
        function
            The jacobian function.
        """
        # TODO: modify this to use relevance and possibly compile multiple jac functions depending
        # on DV/response so that we don't compute any derivatives that are always zero.
        if self._jac_func_ is None:
            fjax = jax.jacfwd if self.best_partial_deriv_direction() == 'fwd' else jax.jacrev
            nstatic = len(self._discrete_inputs)
            # Argument 0 is self; differentiate with respect to every continuous input.
            wrt_idxs = list(range(1, len(self._var_abs2meta['input']) + 1))
            self._jac_func_ = MethodType(fjax(self.compute_primal.__func__, argnums=wrt_idxs),
                                         self)

            if self.options['use_jit']:
                static_argnums = tuple(range(1 + len(wrt_idxs), 1 + len(wrt_idxs) + nstatic))
                self._jac_func_ = MethodType(jit(self._jac_func_.__func__,
                                                 static_argnums=static_argnums),
                                             self)

        return self._jac_func_