ImplicitFuncComp#

ImplicitFuncComp is a component that provides a shortcut for building an ImplicitComponent based on a python function. That function takes inputs and states as arguments and returns residual values. The function may take some inputs that are non-differentiable and are assumed to be static during the computation of derivatives. These static values may be of any hashable type. All other arguments and return values must be either floats or numpy arrays. The mapping between a state argument and its residual output must be specified in the metadata when the output (state) is added by setting ‘resid’ to the name of the residual.

It may seem confusing to use add_output to specify state variables since the state variables are actually input arguments to the function, but in OpenMDAO’s view of the world, states are outputs so we use add_output to specify them. Also, using the metadata to specify which input arguments are actually states gives more flexibility in terms of how the function arguments are ordered. For example, if it’s desirable for a function to be passable to scipy.optimize.newton, then the function’s arguments can be ordered with the states first, followed by the inputs, in order to match the order expected by scipy.optimize.newton.
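
As a rough illustration of the argument-ordering point above, the sketch below puts the state first so that the residual function can be handed to any root finder that calls func(x, *args), which is the calling convention scipy.optimize.newton uses. To stay dependency-free, it uses a small hand-written secant iteration in place of SciPy. The helper secant_solve is hypothetical and is not part of SciPy or OpenMDAO.

```python
def apply_nl(x, a, b, c):  # state x first, followed by inputs a, b, c
    """Residual of the quadratic a*x**2 + b*x + c = 0."""
    return a * x ** 2 + b * x + c


def secant_solve(func, x0, args=(), tol=1e-12, maxiter=50):
    """Hypothetical root finder that calls func(x, *args), state first."""
    x_prev, x_curr = x0, x0 + 1e-4
    f_prev, f_curr = func(x_prev, *args), func(x_curr, *args)
    for _ in range(maxiter):
        if abs(f_curr) < tol:
            break
        denom = f_curr - f_prev
        if denom == 0.0:
            break  # flat secant line, cannot take another step
        x_next = x_curr - f_curr * (x_curr - x_prev) / denom
        x_prev, f_prev = x_curr, f_curr
        x_curr, f_curr = x_next, func(x_next, *args)
    return x_curr


# Because the state comes first, the inputs can be passed through as args.
root = secant_solve(apply_nl, 0.0, args=(2.0, -8.0, 6.0))
print(root)
```

With SciPy installed, the same function should be usable as scipy.optimize.newton(apply_nl, 0.0, args=(2.0, -8.0, 6.0)). When such a function is wrapped for an ImplicitFuncComp, the add_output metadata still identifies x as the state, regardless of its position in the argument list.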

The add_output function is part of the Function Metadata API. You use this API to specify various metadata that OpenMDAO needs in order to properly configure a fully functional implicit component. You should read and understand the Function Metadata API before you continue with this section.

ImplicitFuncComp Options#

| Option | Default | Acceptable Values | Acceptable Types | Description |
| --- | --- | --- | --- | --- |
| always_opt | False | [True, False] | ['bool'] | If True, force nonlinear operations on this component to be included in the optimization loop even if this component is not relevant to the design variables and responses. |
| assembled_jac_type | csc | ['csc', 'dense'] | N/A | Linear solver(s) in this group or implicit component, if using an assembled jacobian, will use this type. |
| distributed | False | [True, False] | ['bool'] | If True, set all variables in this component as distributed across multiple processes. |
| run_root_only | False | [True, False] | ['bool'] | If True, call compute, compute_partials, linearize, apply_linear, apply_nonlinear, and compute_jacvec_product only on rank 0 and broadcast the results to the other ranks. |
| use_jax | False | [True, False] | ['bool'] | If True, use jax to compute derivatives. |
| use_jit | False | [True, False] | ['bool'] | If True, attempt to use jit on the function. This is ignored if use_jax is False. |

ImplicitFuncComp Constructor#

The call signature for the ImplicitFuncComp constructor is:

ImplicitFuncComp.__init__(apply_nonlinear, solve_nonlinear=None, linearize=None, solve_linear=None, **kwargs)

Initialize attributes.

ImplicitFuncComp Example: A simple implicit function component#

The simplest implicit function component requires the definition of a function that takes inputs and states as arguments and returns residual values. This function maps to the apply_nonlinear method in the OpenMDAO component API. Here’s an example:

import openmdao.api as om
import openmdao.func_api as omf

def apply_nl(a, b, c, x):  # inputs a, b, c and state x
    R_x = a * x ** 2 + b * x + c
    return R_x

f = (omf.wrap(apply_nl)
        .add_output('x', resid='R_x', val=0.0)
        .declare_partials(of='*', wrt='*', method='cs')
        )

p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f))

p.model.nonlinear_solver = om.NewtonSolver(solve_subsystems=False, iprint=0)

# need this since comp is implicit and doesn't have a solve_linear
p.model.linear_solver = om.DirectSolver()

p.setup()

p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()
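
To make the role of the nonlinear solver more concrete, here is a minimal sketch of a scalar Newton iteration on the same residual, written in plain Python. It is not OpenMDAO's NewtonSolver, and newton_scalar is a hypothetical helper. The slope dR/dx is estimated with a complex step, loosely mirroring the method='cs' declaration above.

```python
def apply_nl(a, b, c, x):  # inputs a, b, c and state x
    return a * x ** 2 + b * x + c


def newton_scalar(resid, x0, inputs, tol=1e-10, maxiter=20, h=1e-30):
    """Hypothetical scalar Newton loop that drives resid(*inputs, x) to zero."""
    x = x0
    for _ in range(maxiter):
        r = resid(*inputs, x)
        if abs(r) < tol:
            break
        # Complex-step estimate of dR/dx: Im(R(x + i*h)) / h
        dr_dx = resid(*inputs, complex(x, h)).imag / h
        x -= r / dr_dx
    return x


# Same inputs and initial state (val=0.0) as the example above.
x = newton_scalar(apply_nl, 0.0, (2.0, -8.0, 6.0))
print(x)
```

Starting from 0.0, this iteration settles on the root near x = 1. The quadratic also has a root at x = 3, so which root an iterative solve reaches depends on the initial value of the state.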

ImplicitFuncComp Example: Partial Derivatives#

All nonzero partial derivatives for an ImplicitFuncComp must be declared when the function is wrapped. Otherwise, OpenMDAO will assume that all partial derivatives for that component are zero. In the previous model the partials were declared so that the Newton solver could linearize the component, but we only computed outputs. Now we’ll also compute total derivatives. Because our implicit function component does not define its own linearize function, we specify a method of fd or cs when we declare our partials, so they’ll be computed using finite differencing or complex step. This version also passes a solve_nonlinear function to the component, so the component computes its own state and no nonlinear solver is needed. Finally, because our implicit component doesn’t define its own solve_linear function, we have to specify a linear solver for our component.

def apply_nl(a, b, c, x):
    R_x = a * x ** 2 + b * x + c
    return R_x

def solve_nonlinear(a, b, c, x):
    x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)
    return x

f = (omf.wrap(apply_nl)
        .add_output('x', resid='R_x', val=0.0)
        .declare_partials(of='*', wrt='*', method='cs')
        )

p = om.Problem()
comp = p.model.add_subsystem('comp', om.ImplicitFuncComp(f, solve_nonlinear=solve_nonlinear))

# need this since comp is implicit and doesn't have a solve_linear
comp.linear_solver = om.DirectSolver()

p.setup()

p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()


J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])
print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]
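
The totals printed above can be checked by hand. The sketch below, in plain Python and independent of OpenMDAO, estimates the partials of the residual with a complex step and then applies the implicit function theorem, dx/dp = -(dR/dp) / (dR/dx). The helper cs_partials is hypothetical and only illustrates the idea behind method='cs'.

```python
def apply_nl(a, b, c, x):
    return a * x ** 2 + b * x + c


def cs_partials(func, args, h=1e-30):
    """Complex-step derivative of func with respect to each positional argument."""
    derivs = []
    for i in range(len(args)):
        perturbed = list(args)
        perturbed[i] = complex(args[i], h)  # perturb one argument along the imaginary axis
        derivs.append(func(*perturbed).imag / h)
    return derivs


a, b, c = 2.0, -8.0, 6.0
x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)  # same closed form as solve_nonlinear

dR_da, dR_db, dR_dc, dR_dx = cs_partials(apply_nl, (a, b, c, x))

# Implicit function theorem: R(a, b, c, x(a, b, c)) = 0 implies dx/dp = -(dR/dp) / (dR/dx)
totals = {name: -dR / dR_dx for name, dR in (('a', dR_da), ('b', dR_db), ('c', dR_dc))}
print(totals)
```

Up to floating-point rounding, the three values agree with the -2.25, -0.75, and -0.25 reported by compute_totals.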

ImplicitFuncComp Example: Specifying linearize, solve_linear, and solve_nonlinear functions#

The following implicit function component specifies linearize, solve_linear, and solve_nonlinear functions, so no external linear or nonlinear solvers are required to compute outputs or derivatives.

def apply_nl(a, b, c, x):
    R_x = a * x ** 2 + b * x + c
    return R_x

def solve_nonlinear(a, b, c, x):
    x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)
    return x

def linearize(a, b, c, x, partials):
    partials['x', 'a'] = x ** 2
    partials['x', 'b'] = x
    partials['x', 'c'] = 1.0
    partials['x', 'x'] = 2 * a * x + b

    inv_jac = 1.0 / (2 * a * x + b)
    return inv_jac

def solve_linear(d_x, mode, inv_jac):
    if mode == 'fwd':
        d_x = inv_jac * d_x
        return d_x
    elif mode == 'rev':
        dR_x = inv_jac * d_x
        return dR_x

f = (omf.wrap(apply_nl)
        .add_output('x', resid='R_x', val=0.0)
        .declare_partials(of='*', wrt='*')
        )

p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f,
                                                  solve_nonlinear=solve_nonlinear,
                                                  solve_linear=solve_linear, 
                                                  linearize=linearize))
p.setup()

p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()

J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])
print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]
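
To show how linearize and solve_linear fit together, the sketch below calls the two functions by hand, outside OpenMDAO. It assumes, based on the signatures above rather than on OpenMDAO internals, that the value returned by linearize is handed to solve_linear as its last argument. A plain dict stands in for the partials object, and the minus sign from the implicit function theorem is applied explicitly.

```python
def linearize(a, b, c, x, partials):
    partials['x', 'a'] = x ** 2
    partials['x', 'b'] = x
    partials['x', 'c'] = 1.0
    partials['x', 'x'] = 2 * a * x + b
    return 1.0 / (2 * a * x + b)  # inverse of dR/dx, reused by solve_linear


def solve_linear(d_x, mode, inv_jac):
    # For a scalar state the Jacobian equals its transpose,
    # so fwd and rev apply the same scaling.
    if mode == 'fwd':
        return inv_jac * d_x
    elif mode == 'rev':
        return inv_jac * d_x


partials = {}  # stand-in for the object OpenMDAO would pass in (assumption)
inv_jac = linearize(2.0, -8.0, 6.0, 3.0, partials)  # evaluated at the converged state x = 3

# dx/dp = -(dR/dx)^-1 * dR/dp for each input p
totals = {wrt: -solve_linear(partials['x', wrt], 'fwd', inv_jac) for wrt in ('a', 'b', 'c')}
print(totals)
```

The resulting values match the totals printed above. For a component with more than one state, the rev branch would need to apply the inverse of the transposed Jacobian rather than reuse the fwd expression.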

Using jax to Compute Partial Derivatives#

If the function used to instantiate the ImplicitFuncComp declares partials or coloring that use method='jax', or if the component’s use_jax option is set, then the jax AD package will be used to compute all of the component’s derivatives. Currently it’s not possible to mix jax with the approximation methods (‘cs’ and ‘fd’) in the same component.

Note that jax is an optional OpenMDAO dependency, so it may not be present in your environment. You can install it manually. Running

pip install jax
pip install jaxlib

should work in most cases.

To activate jax’s just-in-time compilation capability, set the use_jit option on the component. Note that use_jit is ignored if use_jax is False. For example:

p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f,
                                                  solve_nonlinear=solve_nonlinear,
                                                  solve_linear=solve_linear, 
                                                  linearize=linearize,
                                                  use_jit=True))
p.setup()

p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()

J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])
print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]