ExplicitFuncComp
ExplicitFuncComp is a component that provides a shortcut for building an ExplicitComponent from a Python function. The function must have one or more differentiable arguments, which can be either floats or numpy arrays. It may also have zero or more non-differentiable arguments, which are assumed to remain static during the computation of derivatives. If jax is used, then any non-differentiable arguments must be hashable. The function must return either a single float or numpy array, or a tuple of such values in which each entry represents a different output of the OpenMDAO component. The ExplicitFuncComp automatically takes care of all of the component API methods, so you just need to instantiate it with a function. In most cases that function will need additional metadata, which you can add using the Function Metadata API. You should read and understand the Function Metadata API before you continue with this section.
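The hashability requirement likely comes from how jax compiles functions. `jax.jit` caches compiled code keyed on its static (non-differentiable) arguments, so those arguments must be hashable. The sketch below mimics that kind of cache with `functools.lru_cache` from the Python standard library. It is an analogy only, not jax or OpenMDAO code, and `compiled_variant` and `is_hashable` are hypothetical names.

```python
from functools import lru_cache


# Hypothetical stand-in for a JIT cache (this is NOT jax). Each distinct value of the
# non-differentiable argument `mode` gets its own cached "compiled" function, so the
# argument is used as a cache key and must therefore be hashable.
@lru_cache(maxsize=None)
def compiled_variant(mode):
    if mode == 'linear':
        return lambda x: 2.0 * x
    return lambda x: x * x


def is_hashable(value):
    """Return True if `value` can serve as a cache key."""
    try:
        hash(value)
    except TypeError:
        return False
    return True


print(compiled_variant('linear')(3.0))             # 6.0
print(is_hashable('linear'), is_hashable(('a',)))  # True True
print(is_hashable(['linear']))                     # False: lists are unhashable

# Passing an unhashable static argument fails at the cache lookup.
try:
    compiled_variant(['linear'])
except TypeError as err:
    print('rejected:', err)
```

Tuples, strings, and numbers work as static arguments. Lists, dicts, and numpy arrays do not.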
ExplicitFuncComp Options
Option | Default | Acceptable Values | Acceptable Types | Description |
---|---|---|---|---|
always_opt | False | [True, False] | ['bool'] | If True, force nonlinear operations on this component to be included in the optimization loop even if this component is not relevant to the design variables and responses. |
derivs_method | N/A | ['jax', 'cs', 'fd', None] | N/A | The method to use for computing derivatives. |
distributed | False | [True, False] | ['bool'] | If True, set all variables in this component as distributed across multiple processes. |
run_root_only | False | [True, False] | ['bool'] | If True, call compute, compute_partials, linearize, apply_linear, apply_nonlinear, and compute_jacvec_product only on rank 0 and broadcast the results to the other ranks. |
use_jit | True | [True, False] | ['bool'] | If True, attempt to use jit on compute_primal, assuming jax or some other AD package is active. |
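ExplicitFuncComp Constructor
The call signature for the ExplicitFuncComp constructor is:
- ExplicitFuncComp.__init__(compute, compute_partials=None, **kwargs)
Here compute is the function that computes the outputs, either a plain Python function or one wrapped with omf.wrap. compute_partials is an optional function that computes the partial derivatives (see Using the compute_partials Option below). Any remaining keyword arguments set the component options listed above.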
ExplicitFuncComp Example: Simple
For example, here is a simple component that takes the input and adds one to it.
import openmdao.api as om
prob = om.Problem()
model = prob.model
def func(x=2.0):
    # x becomes an input with default value 2.0; the returned y becomes the output
    y = x + 1.
    return y
model.add_subsystem('comp', om.ExplicitFuncComp(func))
prob.setup()
prob.run_model()
print(prob.get_val('comp.y'))
3.0
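In the example above, y is never declared, yet comp.y exists. The input name and default come from the function signature, and the output name appears to be taken from the variable named in the return statement. The sketch below shows that kind of introspection using Python's standard `ast` module. `guess_io_names` is a hypothetical helper, not OpenMDAO's implementation, and it only handles returns of bare variable names.

```python
import ast


def guess_io_names(src):
    """Hypothetical helper: guess input names/defaults and output names from source."""
    tree = ast.parse(src)
    fdef = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef))
    names = [a.arg for a in fdef.args.args]
    # Defaults line up with the *last* positional arguments.
    first_with_default = len(names) - len(fdef.args.defaults)
    inputs = {name: None for name in names[:first_with_default]}
    for name, node in zip(names[first_with_default:], fdef.args.defaults):
        inputs[name] = ast.literal_eval(node)
    # Output names are the bare variable names in the return statement.
    outputs = []
    for node in ast.walk(fdef):
        if isinstance(node, ast.Return) and node.value is not None:
            vals = node.value.elts if isinstance(node.value, ast.Tuple) else [node.value]
            outputs = [v.id for v in vals if isinstance(v, ast.Name)]
    return inputs, outputs


SIMPLE = """
def func(x=2.0):
    y = x + 1.
    return y
"""

TWO_OUTPUTS = """
def func(a, b=3.0):
    x = 2. * a
    y = b - 1.0 / 3.0
    return x, y
"""

print(guess_io_names(SIMPLE))       # ({'x': 2.0}, ['y'])
print(guess_io_names(TWO_OUTPUTS))  # ({'a': None, 'b': 3.0}, ['x', 'y'])
```

Returning a named variable, rather than an expression such as `return x + 1.`, gives this kind of introspection a name to work with.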
ExplicitFuncComp Example: Arrays
You can declare an ExplicitFuncComp with arrays for inputs and outputs. For inputs, you must either provide default array values or set their ‘shape’ metadata correctly using the Function Metadata API. For outputs, you must provide ‘shape’ metadata as well. In the example below, the shape of the input x is set via the function's default value, and the shape of the output y is set via the add_output method of the Function Metadata API.
import numpy as np
import openmdao.func_api as omf
prob = om.Problem()
model = prob.model
def func(x=np.array([1., 2., 3.])):
    # y keeps the first two entries of x, so its shape (2) must be declared
    y = x[:2]
    return y
f = omf.wrap(func).add_output('y', shape=2)
model.add_subsystem('comp', om.ExplicitFuncComp(f))
prob.setup()
prob.run_model()
print(prob.get_val('comp.y'))
[1. 2.]
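Input shapes can be read from default values without running anything. The output shape, however, depends on what the function body does: here y = x[:2] has two entries even though x has three. OpenMDAO allocates its variable vectors during setup, before the function is ever evaluated, so the output shape has to be stated up front. The pure-Python sketch below makes the mismatch concrete. It uses tuples in place of numpy arrays, and `shape_of` is a hypothetical helper.

```python
def shape_of(value):
    """Hypothetical helper: numpy-style shape of a scalar or (nested) list/tuple."""
    shape = []
    while isinstance(value, (list, tuple)):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


def func(x=(1., 2., 3.)):  # tuple stands in for np.array([1., 2., 3.])
    y = x[:2]
    return y


input_shape = shape_of((1., 2., 3.))  # known from the default value alone
output_shape = shape_of(func())       # only known after running the body
print(input_shape, output_shape)      # (3,) (2,)
```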
ExplicitFuncComp Example: Variable Properties
You can also declare properties like ‘units’ on the inputs and outputs. In this example, defaults(units='inch') declares all of the variables in inches. The value of y is then set in feet, so OpenMDAO converts it to inches before the function is called.
prob = om.Problem()
model = prob.model
def func(x, y):
    z = x + y
    return z
f = omf.wrap(func).defaults(units='inch')
model.add_subsystem('comp', om.ExplicitFuncComp(f))
prob.setup()
prob.set_val('comp.x', 12.0, units='inch')
prob.set_val('comp.y', 1.0, units='ft')
prob.run_model()
print(prob.get_val('comp.z'))
24.0
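The printed result works out as follows. Every variable is declared in inches, so the 1.0 ft assigned to y becomes 12.0 inch, and z = 12.0 + 12.0 = 24.0 inch. The minimal sketch below shows that conversion step. It assumes a hand-written table of conversion factors; `TO_INCH` and `to_inch` are hypothetical, and OpenMDAO uses its own units library.

```python
# Hypothetical conversion table: the factor that turns one unit into inches.
TO_INCH = {'inch': 1.0, 'ft': 12.0}


def to_inch(value, units):
    """Convert `value` expressed in `units` to inches."""
    return value * TO_INCH[units]


x = to_inch(12.0, 'inch')  # already in the declared units
y = to_inch(1.0, 'ft')     # converted from feet on the way in
z = x + y                  # the function body only ever sees inches
print(z)                   # 24.0
```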
ExplicitFuncComp Example: Partial Derivatives
All nonzero partial derivatives for an ExplicitFuncComp must be declared when the function is wrapped. OpenMDAO assumes that any partial derivative that is not declared is zero. For example:
def func(a=2.0, b=3.0):
    x = 2. * a
    y = b - 1.0 / 3.0
    return x, y
f = (omf.wrap(func)
     .defaults(method='cs')  # compute the declared partials by complex step
     .declare_partials(of='x', wrt='a')
     .declare_partials(of='y', wrt='b'))
p = om.Problem()
p.model.add_subsystem('comp', om.ExplicitFuncComp(f))
p.setup()
p.run_model()
J = p.compute_totals(of=['comp.x', 'comp.y'], wrt=['comp.a', 'comp.b'])
print(J['comp.x', 'comp.a'][0][0])
print(J['comp.y', 'comp.b'][0][0])
2.0
1.0
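To see why only two partials need declaring, the sketch below estimates the full 2 by 2 Jacobian of func with the complex-step formula df/dv ≈ Im(f(v + ih))/h, the same idea behind method='cs'. The entries dx/db and dy/da come out as zero, which matches what OpenMDAO assumes for undeclared partials. `cs_jacobian` is a hypothetical helper written with plain Python complex numbers.

```python
def func(a=2.0, b=3.0):
    x = 2. * a
    y = b - 1.0 / 3.0
    return x, y


def cs_jacobian(f, out_names, point, h=1e-30):
    """Hypothetical helper: complex-step Jacobian as {(output, input): value}."""
    jac = {}
    for wrt in point:
        # Perturb one input along the imaginary axis; the others stay real.
        args = {k: complex(v, h if k == wrt else 0.0) for k, v in point.items()}
        for name, val in zip(out_names, f(**args)):
            jac[name, wrt] = val.imag / h
    return jac


J = cs_jacobian(func, ('x', 'y'), {'a': 2.0, 'b': 3.0})
print(J)  # {('x', 'a'): 2.0, ('y', 'a'): 0.0, ('x', 'b'): 0.0, ('y', 'b'): 1.0}
```

Complex step has no subtractive cancellation, so h can be tiny and the results are exact to machine precision for these linear expressions.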
ExplicitFuncComp Example: Sparse Partials
If you know that some of the partials are sparse, you should declare them as sparse to get the best possible performance when computing derivatives for your component. The rows and cols arguments give the row and column indices of the nonzero entries. Here’s an example of a function with sparse (in this case diagonal) partials:
def func(x=np.ones(5), y=np.ones(5)):
    z = x * y
    return z
f = (omf.wrap(func)
     .add_output('z', shape=5)
     # z[i] depends only on x[i] and y[i], so only the diagonal entries are nonzero
     .declare_partials(of='z', wrt=['x', 'y'], method='cs',
                       rows=np.arange(5), cols=np.arange(5)))
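For z = x * y, the Jacobian dz/dx is diagonal with entries y, so only 5 of its 25 entries can be nonzero. rows=np.arange(5) and cols=np.arange(5) list exactly those positions in coordinate (COO) form. The plain-Python sketch below expands such a (rows, cols, values) triple into the dense matrix it stands for. `coo_to_dense` is a hypothetical helper, and the sample x and y values are made up for illustration.

```python
def coo_to_dense(rows, cols, vals, shape):
    """Hypothetical helper: expand COO (rows, cols, vals) into a dense list of lists."""
    dense = [[0.0] * shape[1] for _ in range(shape[0])]
    for r, c, v in zip(rows, cols, vals):
        dense[r][c] = v
    return dense


n = 5
x = [1.0, 2.0, 3.0, 4.0, 5.0]
y = [10.0, 20.0, 30.0, 40.0, 50.0]
rows = cols = list(range(n))  # same pattern as rows=np.arange(5), cols=np.arange(5)

dz_dx = coo_to_dense(rows, cols, y, (n, n))  # dz_i/dx_i = y_i, everything else zero
dz_dy = coo_to_dense(rows, cols, x, (n, n))  # dz_i/dy_i = x_i

nonzeros = sum(v != 0.0 for row in dz_dx for v in row)
print(nonzeros, 'of', n * n)  # 5 of 25
```

Storing only the diagonal keeps the partial at 5 values instead of 25, and the saving grows quadratically with the array size.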
ExplicitFuncComp Example: Default metadata values
Shape and units are metadata that can apply to all of the variables in a component. They can be set via the defaults method of the Function Metadata API. In the following example, all of the variables share the units 'm'. The default shape of 1 applies to the output y, while the input x keeps the shape of its default value, np.ones(5).
import numpy as np
def func(x=np.ones(5)):
    # only the third entry of x is used, so y has a single entry
    y = 2. * x[2]
    return y
f = omf.wrap(func).defaults(shape=1, units='m')
prob = om.Problem()
prob.model.add_subsystem('comp', om.ExplicitFuncComp(f))
prob.setup()
prob.set_val('comp.x', [100., 200., 300., 400., 500.], units='cm')
prob.run_model()
print(prob.get_val('comp.y'))
[6.]
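The printed [6.] follows from the unit conversion. The values set in cm become [1, 2, 3, 4, 5] m, and y = 2 * x[2] = 6 m. Because x still holds five values even though defaults() sets shape=1, it appears that a shape implied by a default argument value takes precedence over the component-wide default. The sketch below encodes that assumed precedence with a hypothetical `resolve_meta` helper. It is not OpenMDAO's code.

```python
def resolve_meta(var_meta, defaults):
    """Hypothetical helper: per-variable metadata wins; defaults fill in the gaps."""
    return {**defaults, **var_meta}


defaults = {'shape': 1, 'units': 'm'}
x_meta = resolve_meta({'shape': 5}, defaults)  # shape implied by x=np.ones(5)
y_meta = resolve_meta({}, defaults)            # nothing else declared for y
print(x_meta, y_meta)  # {'shape': 5, 'units': 'm'} {'shape': 1, 'units': 'm'}

# The value check: centimetres -> metres, then y = 2 * x[2].
x_m = [v / 100.0 for v in [100., 200., 300., 400., 500.]]
print(2. * x_m[2])  # 6.0
```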
Using jax to Compute Partial Derivatives
If the function used to instantiate the ExplicitFuncComp declares partials or coloring that use method='jax', or if the component’s derivs_method option is set to 'jax', then the jax AD package will be used to compute all of the component’s derivatives. Currently it’s not possible to mix jax with the finite difference ('fd') or complex step ('cs') methods in the same component.
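jax computes derivatives by automatic differentiation (AD) rather than by perturbing inputs, which is what 'fd' and 'cs' do. The sketch below illustrates the idea with forward-mode AD using a hypothetical `Dual` number class, which carries a value and a derivative through each arithmetic operation. This is a rough illustration only, not jax's implementation, which also supports reverse mode and compilation. The function `func` here is made up for the illustration.

```python
class Dual:
    """Hypothetical forward-mode AD number: a value plus its derivative w.r.t. one seed."""

    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    @staticmethod
    def _wrap(other):
        # Treat plain numbers as constants (derivative zero).
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = self._wrap(other)
        return Dual(self.val + other.val, self.der + other.der)

    __radd__ = __add__

    def __mul__(self, other):
        other = self._wrap(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)

    __rmul__ = __mul__


def func(x, y):
    return x * y + 2.0 * x


# Seed x (der=1) to get df/dx, then seed y to get df/dy, at x=3, y=4.
dfdx = func(Dual(3.0, 1.0), Dual(4.0)).der
dfdy = func(Dual(3.0), Dual(4.0, 1.0)).der
print(dfdx, dfdy)  # 6.0 3.0
```

Forward mode needs one pass per seeded input. That is one reason AD packages also offer reverse mode, which needs one pass per output.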
Note that jax is not an OpenMDAO dependency that is installed by default, so you’ll have to install it by issuing one of the following commands at your operating system command prompt:
pip install jax jaxlib
pip install openmdao[jax]
pip install openmdao[all]
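jax’s just-in-time compilation capability is controlled by the component’s use_jit option, which defaults to True. It can also be set explicitly when the component is created. The following example reuses f from the previous section: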
prob = om.Problem()
prob.model.add_subsystem('comp', om.ExplicitFuncComp(f, use_jit=True))
prob.setup()
prob.set_val('comp.x', [100., 200., 300., 400., 500.], units='cm')
prob.run_model()
print(prob.get_val('comp.y'))
[6.]
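Using the compute_partials Option
If you know the analytic partial derivatives of your component, you can supply a function that computes them via the compute_partials option. That function receives the same arguments as the compute function, followed by J, which is indexed by (output, input) pairs. Below is an example demonstrating this capability: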
def J_func(x, y, z, J):
    # foo and bar depend elementwise on x and y, so these partials are diagonal
    J['foo', 'x'] = -3*np.log(z)/(3*x+2*y)**2
    J['foo', 'y'] = -2*np.log(z)/(3*x+2*y)**2
    J['bar', 'x'] = 2.*np.ones(4)
    J['bar', 'y'] = np.ones(4)
    # every entry of foo depends on the scalar z, so this partial is dense
    J['foo', 'z'] = 1/(z*(3*x+2*y))
def func(x=np.zeros(4), y=np.ones(4), z=3):
    foo = np.log(z)/(3*x+2*y)
    bar = 2.*x + y
    return foo, bar
f = (omf.wrap(func)
     .defaults(units='m')
     .add_output('foo', units='1/m', shape=4)
     .add_output('bar', shape=4)
     .declare_partials(of='foo', wrt=('x', 'y'), rows=np.arange(4), cols=np.arange(4))
     .declare_partials(of='foo', wrt='z')
     .declare_partials(of='bar', wrt=('x', 'y'), rows=np.arange(4), cols=np.arange(4)))
prob = om.Problem()
prob.model.add_subsystem('comp', om.ExplicitFuncComp(f, compute_partials=J_func))
prob.setup(force_alloc_complex=True)
prob.run_model()
print(prob.get_val('comp.foo'))
print(prob.get_val('comp.bar'))
[0.54930614 0.54930614 0.54930614 0.54930614]
[1. 1. 1. 1.]
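In OpenMDAO, hand-coded partials like J_func are normally verified with prob.check_partials(method='cs'). Passing force_alloc_complex=True to setup, as the example does, allocates the complex vectors such a check needs. The plain-Python sketch below makes the same comparison for the foo partials at the function's default point, using `cmath.log` so that foo accepts complex inputs. The helper names `analytic` and `complex_step` are hypothetical.

```python
import cmath
import math


def foo(x, y, z):
    # Scalar version of foo from func(); cmath.log accepts complex arguments.
    return cmath.log(z) / (3 * x + 2 * y)


def analytic(x, y, z):
    # Same expressions as in J_func, written for scalars.
    d = 3 * x + 2 * y
    return {'x': -3 * math.log(z) / d**2,
            'y': -2 * math.log(z) / d**2,
            'z': 1 / (z * d)}


def complex_step(x, y, z, h=1e-30):
    """Complex-step estimate of d(foo)/d(x, y, z)."""
    return {'x': foo(x + 1j * h, y, z).imag / h,
            'y': foo(x, y + 1j * h, z).imag / h,
            'z': foo(x, y, z + 1j * h).imag / h}


point = (0.0, 1.0, 3.0)  # the defaults used in func()
exact, approx = analytic(*point), complex_step(*point)
for name in ('x', 'y', 'z'):
    print(name, exact[name], approx[name])
    assert math.isclose(exact[name], approx[name], rel_tol=1e-12)
```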