ExplicitFuncComp

ExplicitFuncComp is a component that provides a shortcut for building an ExplicitComponent from a Python function. The function must have one or more differentiable arguments, which can be either floats or numpy arrays. It may also have zero or more non-differentiable arguments, which are assumed to remain static during the computation of derivatives. If jax is used, any non-differentiable arguments must be hashable. The function must return either a single float or numpy array, or a tuple of such values, where each entry in the tuple becomes a separate output of the OpenMDAO component. ExplicitFuncComp implements all of the component API methods for you, so you only need to instantiate it with a function. In most cases that function will need additional metadata, which you can add using the Function Metadata API. You should read and understand the Function Metadata API before continuing with this section.
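
Because jax treats non-differentiable arguments as static values, they must be hashable. One quick way to see which Python values qualify is simply to try hashing them. The helper below (is_hashable is a hypothetical name for illustration, not part of OpenMDAO) is a minimal sketch of that check. Note that numpy arrays are not hashable either, so an array-valued static argument would need a hashable representation such as a tuple.

```python
def is_hashable(value):
    """Return True if hash(value) succeeds (hypothetical helper)."""
    try:
        hash(value)
    except TypeError:
        return False
    return True

# Candidate values for a non-differentiable (static) argument.
print(is_hashable(3))              # True: ints are hashable
print(is_hashable('linear'))       # True: strings are hashable
print(is_hashable((1, 2, 3)))      # True: a tuple of hashables is hashable
print(is_hashable([1, 2, 3]))      # False: lists are mutable and unhashable
print(is_hashable((1, [2, 3])))    # False: a tuple is hashable only if its items are
```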

ExplicitFuncComp Options

| Option | Default | Acceptable Values | Acceptable Types | Description |
|---|---|---|---|---|
| always_opt | False | [True, False] | ['bool'] | If True, force nonlinear operations on this component to be included in the optimization loop even if this component is not relevant to the design variables and responses. |
| distributed | False | [True, False] | ['bool'] | If True, set all variables in this component as distributed across multiple processes. |
| run_root_only | False | [True, False] | ['bool'] | If True, call compute, compute_partials, linearize, apply_linear, apply_nonlinear, and compute_jacvec_product only on rank 0 and broadcast the results to the other ranks. |
| use_jax | False | [True, False] | ['bool'] | If True, use jax to compute derivatives. |
| use_jit | False | [True, False] | ['bool'] | If True, attempt to use jit on the function. This is ignored if use_jax is False. |

ExplicitFuncComp Constructor

The call signature for the ExplicitFuncComp constructor is:

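ExplicitFuncComp.__init__(compute, compute_partials=None, **kwargs)

Initialize attributes. Here compute is the function to wrap, either a plain Python function or one wrapped with omf.wrap; compute_partials, if not None, is a function that computes the component’s partial derivatives; and any other keyword arguments set the component options listed above, such as use_jax or use_jit.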

ExplicitFuncComp Example: Simple

For example, here is a simple component that takes an input x and adds one to it to produce the output y.

import openmdao.api as om

prob = om.Problem()
model = prob.model

def func(x=2.0):
    y = x + 1.
    return y

model.add_subsystem('comp', om.ExplicitFuncComp(func))

prob.setup()

prob.run_model()

print(prob.get_val('comp.y'))
[3.]
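
In the example above, the input x and the output y take their names from the function itself, as the examples in this section suggest: the argument list supplies the inputs and the return statement supplies the outputs. The sketch below illustrates that naming rule with Python's ast module. var_names is a hypothetical helper written for this illustration, not OpenMDAO's implementation, and it only handles bare names in the return statement.

```python
import ast

def var_names(src):
    """Return (input names, output names) for the first function in src.

    Hypothetical helper: inputs come from the argument list, outputs from
    the names listed in the function's return statement.
    """
    fdef = next(n for n in ast.walk(ast.parse(src))
                if isinstance(n, ast.FunctionDef))
    inputs = [a.arg for a in fdef.args.args]
    outputs = []
    for node in ast.walk(fdef):
        if isinstance(node, ast.Return) and node.value is not None:
            vals = node.value.elts if isinstance(node.value, ast.Tuple) else [node.value]
            # Only bare names can become output names in this sketch.
            outputs = [v.id if isinstance(v, ast.Name) else None for v in vals]
    return inputs, outputs

simple = """
def func(x=2.0):
    y = x + 1.
    return y
"""

two_outputs = """
def func(x, y, z=3):
    foo = np.log(z)/(3*x+2*y)
    bar = 2.*x + y
    return foo, bar
"""

print(var_names(simple))       # (['x'], ['y'])
print(var_names(two_outputs))  # (['x', 'y', 'z'], ['foo', 'bar'])
```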

ExplicitFuncComp Example: Arrays

You can declare an ExplicitFuncComp with arrays for inputs and outputs. For array inputs, you must either provide default array values or set their ‘shape’ metadata correctly using the Function Metadata API. Array outputs need ‘shape’ metadata as well. In the example below, the shape of input x comes from its default value, and the shape of output y is set via the add_output method of the Function Metadata API.

import numpy as np
import openmdao.func_api as omf

prob = om.Problem()
model = prob.model

def func(x=np.array([1., 2., 3.])):
    y = x[:2]
    return y

f = omf.wrap(func).add_output('y', shape=2)

model.add_subsystem('comp', om.ExplicitFuncComp(f))

prob.setup()

prob.run_model()

print(prob.get_val('comp.y'))
[1. 2.]
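
Because x has a default value, you can call func directly to confirm the shape you pass to add_output. The check below is plain Python and NumPy, not part of the ExplicitFuncComp API; it just calls the function with its own default arguments.

```python
import inspect
import numpy as np

def func(x=np.array([1., 2., 3.])):
    y = x[:2]
    return y

# Collect each argument's default value from the function signature.
defaults = {name: p.default
            for name, p in inspect.signature(func).parameters.items()}

y = func(**defaults)
print(np.shape(y))  # (2,), which matches add_output('y', shape=2)
```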

ExplicitFuncComp Example: Variable Properties

You can also declare properties like ‘units’ on the inputs and outputs. In this example, defaults(units='inch') declares every variable in inches, so setting y with units='ft' triggers a conversion from feet to inches.

prob = om.Problem()
model = prob.model

def func(x, y):
    z = x + y
    return z

f = omf.wrap(func).defaults(units='inch')

model.add_subsystem('comp', om.ExplicitFuncComp(f))

prob.setup()

prob.set_val('comp.x', 12.0, units='inch')
prob.set_val('comp.y', 1.0, units='ft')

prob.run_model()

print(prob.get_val('comp.z'))
[24.]
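
The printed value follows from a single conversion: set_val with units='ft' converts the value into the variable's declared units (inches), so func adds two values that are both in inches. Worked by hand:

```python
IN_PER_FT = 12.0        # 1 ft = 12 in (exact)

x_in = 12.0             # x is set directly in inches
y_in = 1.0 * IN_PER_FT  # 1 ft converted to inches
z_in = x_in + y_in      # func(x, y) returns x + y
print(z_in)             # 24.0
```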

ExplicitFuncComp Example: Partial Derivatives

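All nonzero partial derivatives for an ExplicitFuncComp must be declared when the function is wrapped; OpenMDAO assumes that any undeclared partial derivative is zero. In the example below, x depends only on a and y depends only on b, so only those two partials are declared, and both are computed by complex step ('cs'):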

def func(a=2.0, b=3.0):
    x = 2. * a
    y = b - 1.0 / 3.0
    return x, y

f = (omf.wrap(func)
        .defaults(method='cs')
        .declare_partials(of='x', wrt='a')
        .declare_partials(of='y', wrt='b'))

p = om.Problem()
p.model.add_subsystem('comp', om.ExplicitFuncComp(f))
p.setup()
p.run_model()
J = p.compute_totals(of=['comp.x', 'comp.y'], wrt=['comp.a', 'comp.b'])
print(J['comp.x', 'comp.a'][0][0])
print(J['comp.y', 'comp.b'][0][0])
2.0
1.0
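
With method='cs', a declared partial is computed by complex step: perturb one input by a tiny imaginary amount ih and read the derivative from the imaginary part of the output, df/da ≈ Im[f(a + ih)] / h. The standalone sketch below (cs_partial is a hypothetical helper, not an OpenMDAO function) reproduces the two values above and shows that the undeclared partials are indeed zero.

```python
def func(a=2.0, b=3.0):
    x = 2. * a
    y = b - 1.0 / 3.0
    return x, y

def cs_partial(f, point, wrt, out, h=1e-40):
    """Complex-step estimate of d f(...)[out] / d point[wrt] (hypothetical helper)."""
    args = list(point)
    args[wrt] = args[wrt] + 1j * h  # tiny imaginary perturbation of one input
    return f(*args)[out].imag / h

point = (2.0, 3.0)
print(cs_partial(func, point, wrt=0, out=0))  # dx/da -> 2.0
print(cs_partial(func, point, wrt=1, out=1))  # dy/db -> 1.0
print(cs_partial(func, point, wrt=1, out=0))  # dx/db -> 0.0 (not declared)
print(cs_partial(func, point, wrt=0, out=1))  # dy/da -> 0.0 (not declared)
```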

ExplicitFuncComp Example: Sparse Partials

If you know that some of the partials are sparse, then you should declare them as sparse in order to get the best possible performance when computing derivatives for your component. Here’s an example of a function with sparse, in this case diagonal, partials:

def func(x=np.ones(5), y=np.ones(5)):
    z = x * y
    return z

f = (omf.wrap(func)
        .add_output('z', shape=5)
        .declare_partials(of='z', wrt=['x', 'y'], method='cs', rows=np.arange(5), cols=np.arange(5)))
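
To see why rows=np.arange(5) and cols=np.arange(5) are the right sparsity pattern, you can build the dense Jacobian of z = x * y with respect to x column by column using complex step and check that only the diagonal is nonzero. This NumPy sketch illustrates the pattern only; it is not how OpenMDAO computes sparse partials internally, and the test point is arbitrary.

```python
import numpy as np

def func(x=np.ones(5), y=np.ones(5)):
    z = x * y
    return z

n = 5
x = np.arange(1., n + 1)        # arbitrary test point
y = np.arange(10., 10. + n)
h = 1e-30

# Dense dz/dx by complex step, one column per perturbed entry of x.
dense = np.zeros((n, n))
for j in range(n):
    xp = x.astype(complex)
    xp[j] += 1j * h
    dense[:, j] = func(xp, y).imag / h

rows = cols = np.arange(n)      # the declared sparsity pattern
mask = np.zeros((n, n), dtype=bool)
mask[rows, cols] = True

print(np.all(dense[~mask] == 0.0))        # True: nothing off the diagonal
print(np.allclose(dense[rows, cols], y))  # True: dz_i/dx_i = y_i
```

Only the five diagonal values need to be stored and computed, which is where the performance gain of declaring the partials sparse comes from.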

ExplicitFuncComp Example: Default Metadata Values

Shape and units are metadata that can apply to all variables in a component, and they can be set via the defaults method of the Function Metadata API. In the following example, all variables default to units of 'm'. The default shape of 1 applies to y, while x gets a shape of 5 from its default value, np.ones(5).

import numpy as np

def func(x=np.ones(5)):
    y = 2. * x[2]
    return y

f = omf.wrap(func).defaults(shape=1, units='m')

prob = om.Problem()

prob.model.add_subsystem('comp', om.ExplicitFuncComp(f))

prob.setup()

prob.set_val('comp.x', [100., 200., 300., 400., 500.], units='cm')

prob.run_model()

print(prob.get_val('comp.y'))
[6.]
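
The sketch below mimics how the printed [6.] arises. It assumes that per-variable information (here, the five-element default value of x) takes precedence over the values passed to defaults(), which is what the five-element set_val call implies; merge_meta is a hypothetical helper, not part of the Function Metadata API. The second half is the cm to m conversion worked by hand.

```python
import numpy as np

def merge_meta(defaults, per_var):
    """Hypothetical helper: per-variable metadata overrides the defaults."""
    return {**defaults, **per_var}

defaults = {'shape': 1, 'units': 'm'}
x_meta = merge_meta(defaults, {'shape': np.ones(5).shape})  # shape from x's default value
y_meta = merge_meta(defaults, {})                           # nothing variable-specific

x_cm = np.array([100., 200., 300., 400., 500.])
x_m = x_cm / 100.0   # 1 m = 100 cm
y = 2. * x_m[2]      # func body: y = 2. * x[2]

print(x_meta)  # {'shape': (5,), 'units': 'm'}
print(y_meta)  # {'shape': 1, 'units': 'm'}
print(y)       # 6.0
```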

Using jax to Compute Partial Derivatives

If the function used to instantiate the ExplicitFuncComp declares partials or coloring that use method='jax', or if the component’s use_jax option is set, then jax’s automatic differentiation will be used to compute all of the component’s derivatives. Currently it’s not possible to mix jax with complex step ('cs') or finite difference ('fd') in the same component.

Note that jax is not an OpenMDAO dependency, so you’ll have to install it manually.

pip install jax
pip install jaxlib

should work in most cases.
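
Because jax is optional, it can be useful to check whether it is importable before turning on use_jax. A minimal standard-library sketch (jax_available is a hypothetical helper name):

```python
import importlib.util

def jax_available():
    """Return True if the jax package can be imported in this environment."""
    return importlib.util.find_spec('jax') is not None

use_jax = jax_available()
print(use_jax)  # True or False, depending on the environment
# The result could then be passed on, e.g. om.ExplicitFuncComp(f, use_jax=use_jax)
```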

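To activate jax’s just-in-time compilation capability, set the use_jit option on the component. Because use_jit is ignored when use_jax is False, use_jax must be set as well. This example reuses f from the previous section:

prob = om.Problem()

prob.model.add_subsystem('comp', om.ExplicitFuncComp(f, use_jax=True, use_jit=True))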

prob.setup()

prob.set_val('comp.x', [100., 200., 300., 400., 500.], units='cm')

prob.run_model()

print(prob.get_val('comp.y'))
[6.]

Using the compute_partials Option

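If you know the analytic partial derivatives of your component, you can provide them through the compute_partials argument of the ExplicitFuncComp constructor. The function you pass receives the component’s inputs followed by a Jacobian object J, which it fills in for each declared (output, input) pair. Below is an example demonstrating this capability: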

def J_func(x, y, z, J):
    J['foo', 'x'] = -3*np.log(z)/(3*x+2*y)**2
    J['foo', 'y'] = -2*np.log(z)/(3*x+2*y)**2

    J['bar', 'x'] = 2.*np.ones(4)
    J['bar', 'y'] = np.ones(4)

    J['foo', 'z'] = 1/(z*(3*x+2*y))

def func(x=np.zeros(4), y=np.ones(4), z=3):
    foo = np.log(z)/(3*x+2*y)
    bar = 2.*x + y
    return foo, bar

f = (omf.wrap(func)
        .defaults(units='m')
        .add_output('foo', units='1/m', shape=4)
        .add_output('bar', shape=4)
        .declare_partials(of='foo', wrt=('x', 'y'), rows=np.arange(4), cols=np.arange(4))
        .declare_partials(of='foo', wrt='z')
        .declare_partials(of='bar', wrt=('x', 'y'), rows=np.arange(4), cols=np.arange(4)))

prob = om.Problem()
prob.model.add_subsystem('comp', om.ExplicitFuncComp(f, compute_partials=J_func))
prob.setup(force_alloc_complex=True)
prob.run_model()

print(prob.get_val('comp.foo'))
print(prob.get_val('comp.bar'))
[0.54930614 0.54930614 0.54930614 0.54930614]
[1. 1. 1. 1.]
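
OpenMDAO’s Problem.check_partials is the usual way to validate a hand-written compute_partials function. For a quick standalone check outside of a Problem, the NumPy sketch below compares J_func with complex-step estimates at the example’s default point. The plain dict J and the cs_partial helper are stand-ins for illustration, not OpenMDAO objects.

```python
import numpy as np

def func(x=np.zeros(4), y=np.ones(4), z=3):
    foo = np.log(z)/(3*x+2*y)
    bar = 2.*x + y
    return foo, bar

def J_func(x, y, z, J):
    J['foo', 'x'] = -3*np.log(z)/(3*x+2*y)**2
    J['foo', 'y'] = -2*np.log(z)/(3*x+2*y)**2
    J['bar', 'x'] = 2.*np.ones(4)
    J['bar', 'y'] = np.ones(4)
    J['foo', 'z'] = 1/(z*(3*x+2*y))

point = {'x': np.zeros(4), 'y': np.ones(4), 'z': 3.0}

# A plain dict stands in for the Jacobian object OpenMDAO passes to J_func.
J = {}
J_func(J=J, **point)

def cs_partial(out, wrt, h=1e-30):
    """Complex-step partial of output `out` with respect to input `wrt`.

    Every entry of `wrt` is perturbed at once. That is valid here only
    because the x and y partials are diagonal and z is a scalar.
    """
    args = {k: np.asarray(v, dtype=complex) for k, v in point.items()}
    args[wrt] = args[wrt] + 1j * h
    foo, bar = func(**args)
    return {'foo': foo, 'bar': bar}[out].imag / h

# Prints one line per declared partial.
for (out, wrt), analytic in J.items():
    ok = np.allclose(analytic, cs_partial(out, wrt))
    print(out, wrt, 'OK' if ok else 'MISMATCH')
```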