Source code for openmdao.components.explicit_func_comp
"""Define the ExplicitFuncComp class."""importsysimporttracebackimportnumpyasnpfromopenmdao.core.explicitcomponentimportExplicitComponentimportopenmdao.func_apiasomffromopenmdao.components.func_comp_commonimport_check_var_name,_copy_with_ignore, \
jac_forward,jac_reverse,_get_tangents,_ensure_iterfromopenmdao.utils.array_utilsimportshape_to_lenfromopenmdao.utils.om_warningsimportissue_warningtry:importjaxfromjaximportjitjax.config.update("jax_enable_x64",True)# jax by default uses 32 bit floatsexceptException:_,err,tb=sys.exc_info()ifnotisinstance(err,ImportError):traceback.print_tb(tb)jax=NoneifjaxisnotNone:try:fromjaximportArrayasJaxArrayexceptImportError:# versions of jax before 0.3.18 do not have the jax.Array base classraiseRuntimeError("An unsupported version of jax is installed. ""OpenMDAO requires 'jax>=4.0' and 'jaxlib>=4.0'. ""Try 'pip install openmdao[jax]' with Python>=3.8.")
class ExplicitFuncComp(ExplicitComponent):
    """
    A component that wraps a python function.

    Parameters
    ----------
    compute : function
        The function to be wrapped by this Component.
    compute_partials : function or None
        If not None, call this function when computing partials.
    **kwargs : named args
        Args passed down to ExplicitComponent.

    Attributes
    ----------
    _compute : callable
        The function wrapper used by this component.
    _compute_jax : callable
        Function decorated to ensure use of jax numpy.
    _compute_partials : function or None
        If not None, call this function when computing partials.
    _tangents : tuple
        Tuple of parts of the tangent matrix cached for jax derivative computation.
    _tangent_direction : str
        Direction of the last tangent computation.
    """
    def __init__(self, compute, compute_partials=None, **kwargs):
        """
        Initialize attributes.
        """
        super().__init__(**kwargs)
        self._compute = omf.wrap(compute)
        # in case we're doing jit, force setup of wrapped func because we compute output shapes
        # during setup and that won't work on a jit compiled function
        if self._compute._call_setup:
            self._compute._setup()

        if self._compute._use_jax:
            self.options['derivs_method'] = 'jax'

        if self.options['derivs_method'] == 'jax':
            if jax is None:
                raise RuntimeError(f"{self.msginfo}: jax is not installed. "
                                   "Try 'pip install openmdao[jax]' with Python>=3.8.")
            self._compute_jax = omf.jax_decorate(self._compute._f)

        self._tangents = None
        self._tangent_direction = None

        self._compute_partials = compute_partials
        if self.options['derivs_method'] == 'jax' and self.options['use_jit']:
            static_argnums = [i for i, m in enumerate(self._compute._inputs.values())
                              if 'is_option' in m]
            try:
                self._compute_jax = jit(self._compute_jax, static_argnums=static_argnums)
            except Exception as err:
                issue_warning(f"{self.msginfo}: failed jit compile of compute function: {err}. "
                              "Falling back to using non-jitted function.")
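    # Illustrative usage (a minimal sketch, not part of this module; `paraboloid`
    # and the variable names below are hypothetical):
    #
    #     import openmdao.api as om
    #     import openmdao.func_api as omf
    #
    #     def paraboloid(x=0.0, y=0.0):
    #         return (x - 3.0) ** 2 + x * y + (y + 4.0) ** 2 - 3.0
    #
    #     f = omf.wrap(paraboloid).add_output('f_xy')
    #     p = om.Problem()
    #     p.model.add_subsystem('comp', om.ExplicitFuncComp(f))
    #     p.setup()
    #     p.run_model()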
    @property
    def _mode(self):
        """
        Return the current system mode.

        Returns
        -------
        str
            The current system mode, 'fwd' or 'rev'.
        """
        return self.best_partial_deriv_direction()
    def setup(self):
        """
        Define our inputs and outputs.
        """
        optignore = {'is_option'}
        use_jax = self.options['derivs_method'] == 'jax' and jax is not None

        for name, meta in self._compute.get_input_meta():
            _check_var_name(self, name)
            if 'is_option' in meta and meta['is_option']:
                kwargs = _copy_with_ignore(meta, omf._allowed_declare_options_args,
                                           ignore=optignore)
                self.options.declare(name, **kwargs)
            else:
                kwargs = omf._filter_dict(meta, omf._allowed_add_input_args)
                if use_jax:
                    # make sure internal openmdao values are numpy arrays and not jax Arrays
                    self._dev_arrays_to_np_arrays(kwargs)
                self.add_input(name, **kwargs)

        for i, (name, meta) in enumerate(self._compute.get_output_meta()):
            _check_var_name(self, name)
            kwargs = _copy_with_ignore(meta, omf._allowed_add_output_args, ignore=('resid',))
            if use_jax:
                # make sure internal openmdao values are numpy arrays and not jax Arrays
                self._dev_arrays_to_np_arrays(kwargs)
            self.add_output(name, **kwargs)
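    # Illustrative sketch (hypothetical names): a function arg tagged as an option
    # through the func_api becomes an OpenMDAO option here rather than a continuous
    # input, so it is excluded from differentiation and treated as a static arg:
    #
    #     def f(x=0.0, scale=2.0):
    #         return x * scale
    #
    #     wrapped = omf.wrap(f).declare_option('scale', default=2.0).add_output('y')
    #     comp = om.ExplicitFuncComp(wrapped)  # 'scale' appears in comp.options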
    def _setup_jax(self):
        # TODO: this is here to prevent the ExplicitComponent base class from trying to do its
        # own jax setup if derivs_method is 'jax'. We should probably refactor this...
        pass

    def _dev_arrays_to_np_arrays(self, meta):
        if 'val' in meta:
            if isinstance(meta['val'], JaxArray):
                meta['val'] = np.asarray(meta['val'])

    def _linearize(self, jac=None, sub_do_ln=False):
        """
        Compute jacobian / factorization.

        The model is assumed to be in a scaled state.

        Parameters
        ----------
        jac : Jacobian or None
            Ignored.
        sub_do_ln : bool
            Flag indicating if the children should call linearize on their linear solvers.
        """
        if self.options['derivs_method'] == 'jax':
            if self._mode != self._tangent_direction:
                # force recomputation of coloring and tangents
                self._first_call_to_linearize = True
                self._tangents = None
            self._check_first_linearize()
            self._jax_linearize()
        else:
            super()._linearize(jac, sub_do_ln)

    def _jax_linearize(self):
        """
        Compute the jacobian using jax.

        This updates self._jacobian.
        """
        inames = list(self._compute.get_input_names())
        # argnums specifies which positional args are to be differentiated
        argnums = [i for i, m in enumerate(self._compute._inputs.values())
                   if 'is_option' not in m]
        # keep this around for use locally even if we pass None as argnums to jax
        argidxs = argnums
        if len(argnums) == len(inames):
            argnums = None  # speedup if there are no static args
        osize = len(self._outputs)
        isize = len(self._inputs)
        invals = list(self._func_values(self._inputs))
        coloring = self._coloring_info.coloring
        func = self._compute_jax

        if self._mode == 'rev':  # use reverse mode to compute derivs
            outvals = tuple(self._outputs.values())
            tangents = self._get_tangents(outvals, 'rev', coloring)
            if coloring is None:
                j = np.empty((osize, isize), dtype=float)
                cstart = cend = 0
                for i, a in zip(argidxs, jac_reverse(func, argnums, tangents)(*invals)):
                    if isinstance(invals[i], np.ndarray):
                        cend += invals[i].size
                    else:  # must be a scalar
                        cend += 1
                    a = np.asarray(a)
                    if a.ndim < 2:
                        j[:, cstart:cend] = a.reshape((a.size, 1))
                    else:
                        j[:, cstart:cend] = a.reshape((a.shape[0], cend - cstart))
                    cstart = cend
            else:
                j = [np.asarray(a).reshape((a.shape[0], shape_to_len(a.shape[1:])))
                     for a in jac_reverse(func, argnums, tangents)(*invals)]
                j = coloring.expand_jac(np.hstack(j), 'rev')
        else:
            tangents = self._get_tangents(invals, 'fwd', coloring, argnums)
            if coloring is None:
                j = np.empty((osize, isize), dtype=float)
                start = end = 0
                for a in jac_forward(func, argnums, tangents)(*invals):
                    a = np.asarray(a)
                    if a.ndim < 2:
                        a = a.reshape((1, a.size))
                    else:
                        a = a.reshape((shape_to_len(a.shape[:-1]), a.shape[-1]))
                    end += a.shape[0]
                    if osize == 1:
                        j[0, start:end] = a
                    else:
                        j[start:end, :] = a
                    start = end
            else:
                j = [np.asarray(a).reshape((shape_to_len(a.shape[:-1]), a.shape[-1]))
                     for a in jac_forward(func, argnums, tangents)(*invals)]
                j = coloring.expand_jac(np.vstack(j), 'fwd')

        self._jacobian.set_dense_jac(self, j)

    def _get_tangents(self, vals, direction, coloring=None, argnums=None):
        """
        Return a tuple of tangents values for use with vmap.

        Parameters
        ----------
        vals : list
            List of function input values.
        direction : str
            Derivative computation direction ('fwd' or 'rev').
        coloring : Coloring or None
            If not None, the Coloring object used to compute a compressed tangent array.
        argnums : list of int or None
            Indices of dynamic (differentiable) function args.

        Returns
        -------
        tuple of ndarray or ndarray
            The tangents values to be passed to vmap.
        """
        if self._tangents is None:
            self._tangents = _get_tangents(vals, direction, coloring, argnums)
            self._tangent_direction = direction
        return self._tangents
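    # The tangent machinery above amounts to pushing seed vectors (identity columns,
    # or coloring-compressed columns) through jax. A standalone sketch of the
    # forward-mode idea using only public jax APIs (names here are hypothetical,
    # not this module's jac_forward/jac_reverse helpers):
    #
    #     import jax
    #     import jax.numpy as jnp
    #
    #     def f(x):
    #         return jnp.array([x[0] * x[1], x[0] + x[1]])
    #
    #     x = jnp.array([3.0, 4.0])
    #     tangents = jnp.eye(2)  # one seed column per input entry
    #     # vmap a jvp over the seed columns; row i of the result is J @ e_i
    #     jac_cols = jax.vmap(lambda t: jax.jvp(f, (x,), (t,))[1])(tangents)
    #     J = jac_cols.T  # rows = outputs, cols = inputs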
    def compute(self, inputs, outputs):
        """
        Compute the result of calling our function with the given inputs.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables.
        outputs : Vector
            Unscaled, dimensional output variables.
        """
        outputs.set_vals(_ensure_iter(self._compute(*self._func_values(inputs))))
    def _setup_partials(self):
        """
        Check that all partials are declared.
        """
        for kwargs in self._compute.get_declare_partials():
            self.declare_partials(**kwargs)

        kwargs = self._compute.get_declare_coloring()
        if kwargs is not None:
            self.declare_coloring(**kwargs)

        super()._setup_partials()
    def compute_partials(self, inputs, partials):
        """
        Compute sub-jacobian parts.

        The model is assumed to be in an unscaled state.

        Parameters
        ----------
        inputs : Vector
            Unscaled, dimensional input variables read via inputs[key].
        partials : Jacobian
            Sub-jac components written to partials[output_name, input_name].
        """
        if self._compute_partials is None:
            return

        self._compute_partials(*self._func_values(inputs), partials)
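    # Illustrative sketch (hypothetical function names): a user-supplied
    # compute_partials receives the function's input values followed by the
    # partials object, matching the call above:
    #
    #     def paraboloid(x=0.0, y=0.0):
    #         return (x - 3.0) ** 2 + x * y + (y + 4.0) ** 2 - 3.0
    #
    #     def paraboloid_partials(x, y, partials):
    #         partials['f_xy', 'x'] = 2.0 * (x - 3.0) + y
    #         partials['f_xy', 'y'] = x + 2.0 * (y + 4.0)
    #
    #     f = (omf.wrap(paraboloid).add_output('f_xy')
    #          .declare_partials(of='f_xy', wrt=['x', 'y']))
    #     comp = om.ExplicitFuncComp(f, compute_partials=paraboloid_partials)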
    def _func_values(self, inputs):
        """
        Yield current function input args.

        Parameters
        ----------
        inputs : Vector
            The input vector.

        Yields
        ------
        object
            Value of current function input variable.
        """
        inps = inputs.values()

        for name, meta in self._compute._inputs.items():
            if 'is_option' in meta:
                yield self.options[name]
            else:
                yield next(inps)

    def _compute_coloring(self, recurse=False, **overrides):
        """
        Compute a coloring of the partial jacobian.

        This assumes that the current System is in a proper state for computing derivatives.
        It just calls the base class version and then resets the tangents so that after
        coloring a new set of compressed tangents values can be computed.

        Parameters
        ----------
        recurse : bool
            If True, recurse from this system down the system hierarchy.  Whenever a group
            is encountered that has specified its coloring metadata, we don't recurse below
            that group unless that group has a subsystem that has a nonlinear solver that
            uses gradients.
        **overrides : dict
            Any args that will override either default coloring settings or coloring settings
            resulting from an earlier call to declare_coloring.

        Returns
        -------
        list of Coloring
            The computed colorings.
        """
        ret = super()._compute_coloring(recurse, **overrides)
        self._tangents = None  # reset to compute new colored tangents later
        return ret
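    # Illustrative sketch (hypothetical names and values): partial coloring can be
    # requested through the func_api, in which case the tangents used above are
    # compressed and reset whenever a new coloring is computed:
    #
    #     f = (omf.wrap(some_sparse_func)
    #          .add_output('y', shape=(10,))
    #          .declare_coloring(wrt='*'))
    #     comp = om.ExplicitFuncComp(f, derivs_method='jax')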