# ImplicitFuncComp

ImplicitFuncComp is a component that provides a shortcut for building an ImplicitComponent based on a Python function. That function takes inputs and states as arguments and returns residual values. The mapping between a state and its residual output must be specified in the metadata when the output (state) is added, by setting `resid` to the name of the residual.

It may seem confusing to use add_output to specify state variables since the state variables are actually input arguments to the function, but in OpenMDAO’s view of the world, states are outputs so we use add_output to specify them. Also, using the metadata to specify which input arguments are actually states gives more flexibility in terms of how the function arguments are ordered. For example, if it’s desirable for a function to be passable to scipy.optimize.newton, then the function’s arguments can be ordered with the states first, followed by the inputs, in order to match the order expected by scipy.optimize.newton.
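As a sketch of that ordering idea, the function below puts the state `x` first, followed by the inputs, which is the calling convention `scipy.optimize.newton` expects. The `newton` helper here is a hypothetical stand-in written in plain Python so the sketch runs without SciPy; it is not SciPy's implementation.

```python
def apply_nl(x, a, b, c):  # state x first, then inputs a, b, c
    return a * x ** 2 + b * x + c

def newton(func, x0, args=(), tol=1e-12, maxiter=50):
    # minimal stand-in for scipy.optimize.newton; calls func(x, *args)
    # exactly the way scipy would, using a forward-difference slope
    x = x0
    h = 1e-8
    for _ in range(maxiter):
        r = func(x, *args)
        if abs(r) < tol:
            break
        dr = (func(x + h, *args) - r) / h
        x = x - r / dr
    return x

root = newton(apply_nl, 0.0, args=(2., -8., 6.))
```

Because the state leads the argument list, the very same `apply_nl` could be handed to `scipy.optimize.newton(apply_nl, 0.0, args=(2., -8., 6.))` unchanged.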

The add_output function is part of the Function Metadata API. You use this API to specify various metadata that OpenMDAO needs in order to properly configure a fully functional implicit component. You should read and understand the Function Metadata API before you continue with this section.

## ImplicitFuncComp Options

| Option | Default | Acceptable Values | Acceptable Types | Description | Deprecation |
|---|---|---|---|---|---|
| assembled_jac_type | csc | ['csc', 'dense'] | N/A | Linear solver(s) in this group or implicit component, if using an assembled jacobian, will use this type. | N/A |
| distributed | False | [True, False] | ['bool'] | True if ALL variables in this component are distributed across multiple processes. | The 'distributed' option has been deprecated. Individual inputs and outputs should be set as distributed instead, using calls to add_input() or add_output(). |
| run_root_only | False | [True, False] | ['bool'] | If True, call compute/compute_partials/linearize/apply_linear/apply_nonlinear/compute_jacvec_product only on rank 0 and broadcast the results to the other ranks. | N/A |
| use_jax | False | [True, False] | ['bool'] | If True, use jax to compute derivatives. | N/A |
| use_jit | False | [True, False] | ['bool'] | If True, attempt to use jit on the function. This is ignored if use_jax is False. | N/A |

## ImplicitFuncComp Constructor

The call signature for the ImplicitFuncComp constructor is:

```python
ImplicitFuncComp.__init__(apply_nonlinear, solve_nonlinear=None, linearize=None, solve_linear=None, **kwargs)
```

Initialize attributes.

## ImplicitFuncComp Example: A simple implicit function component

The simplest implicit function component requires the definition of a function that takes inputs and states as arguments and returns residual values. This function maps to the apply_nonlinear method in the OpenMDAO component API. Here’s an example:

```python
import openmdao.api as om
import openmdao.func_api as omf

def apply_nl(a, b, c, x):  # inputs a, b, c and state x
    R_x = a * x ** 2 + b * x + c
    return R_x

f = (omf.wrap(apply_nl)
        .add_output('x', resid='R_x')  # map state 'x' to its residual 'R_x'
        .declare_partials(of='*', wrt='*', method='cs'))

p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f))

p.model.nonlinear_solver = om.NewtonSolver(solve_subsystems=False, iprint=0)

# need this since comp is implicit and doesn't have a solve_linear
p.model.linear_solver = om.DirectSolver()

p.setup()

p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()
```
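For reference, with a = 2, b = -8, c = 6 the residual above factors as 2(x - 1)(x - 3), so the Newton solver can only converge to x = 1 or x = 3. A quick standalone check of those roots (plain Python, no OpenMDAO required):

```python
def residual(a, b, c, x):
    return a * x ** 2 + b * x + c

a, b, c = 2., -8., 6.
# roots from the quadratic formula; these are the states the solver can reach
disc = (b ** 2 - 4 * a * c) ** 0.5
roots = sorted([(-b - disc) / (2 * a), (-b + disc) / (2 * a)])
print(roots)  # -> [1.0, 3.0]
```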


## ImplicitFuncComp Example: Partial Derivatives

All nonzero partial derivatives for an ImplicitFuncComp must be declared when the function is wrapped; otherwise, OpenMDAO will assume that all partial derivatives for that component are zero. In this example we compute total derivatives across the component, so the partials must be declared. Because our implicit function component does not define its own linearize function, we specify a method of 'fd' or 'cs' when we declare our partials, so they will be computed using finite differencing or complex step. This example also supplies a solve_nonlinear function that computes the state directly from the quadratic formula, so no external nonlinear solver is needed. Finally, because our implicit component doesn't define its own solve_linear function, we have to specify a linear solver for it.
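What complex step computes can be sketched in plain Python: perturb the state by a tiny imaginary step and read the derivative off the imaginary part of the residual. This is a standalone illustration of the complex-step idea, not OpenMDAO's actual implementation:

```python
def apply_nl(a, b, c, x):
    return a * x ** 2 + b * x + c

a, b, c, x = 2., -8., 6., 3.
h = 1e-30  # the step can be tiny because there is no subtractive cancellation
dR_dx = apply_nl(a, b, c, x + 1j * h).imag / h
print(dR_dx)  # analytic value is 2*a*x + b = 4.0
```

Unlike finite differencing, the result is accurate to machine precision even for extremely small steps, which is why 'cs' is generally preferred when the function accepts complex inputs.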

```python
def apply_nl(a, b, c, x):
    R_x = a * x ** 2 + b * x + c
    return R_x

def solve_nonlinear(a, b, c, x):
    x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)
    return x

f = (omf.wrap(apply_nl)
        .add_output('x', resid='R_x')  # map state 'x' to its residual 'R_x'
        .declare_partials(of='*', wrt='*', method='cs'))

p = om.Problem()
comp = p.model.add_subsystem('comp', om.ImplicitFuncComp(f, solve_nonlinear=solve_nonlinear))

# need this since comp is implicit and doesn't have a solve_linear
comp.linear_solver = om.DirectSolver()

p.setup()

p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()

J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])
print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
```

```
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]
```
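Those totals can be checked by hand with the implicit function theorem: at a converged state, dx/dw = -(∂R/∂w)/(∂R/∂x) for each input w. With a = 2, b = -8, c = 6 the state solves to x = 3, which reproduces the printed values. A plain-Python check (no OpenMDAO required):

```python
a, b, c = 2., -8., 6.
x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)  # converged state, x = 3

dR_dx = 2 * a * x + b  # partial of the residual w.r.t. the state
totals = {
    'a': -(x ** 2) / dR_dx,  # dx/da
    'b': -x / dR_dx,         # dx/db
    'c': -1.0 / dR_dx,       # dx/dc
}
print(totals)  # -> {'a': -2.25, 'b': -0.75, 'c': -0.25}
```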


## ImplicitFuncComp Example: Specifying linearize, solve_linear, and solve_nonlinear functions

The following implicit function component specifies linearize, solve_linear, and solve_nonlinear functions, so no external linear or nonlinear solvers are required to compute outputs or derivatives.

```python
def apply_nl(a, b, c, x):
    R_x = a * x ** 2 + b * x + c
    return R_x

def solve_nonlinear(a, b, c, x):
    x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)
    return x

def linearize(a, b, c, x, partials):
    partials['x', 'a'] = x ** 2
    partials['x', 'b'] = x
    partials['x', 'c'] = 1.0
    partials['x', 'x'] = 2 * a * x + b

    inv_jac = 1.0 / (2 * a * x + b)
    return inv_jac

def solve_linear(d_x, mode, inv_jac):
    if mode == 'fwd':
        d_x = inv_jac * d_x
        return d_x
    elif mode == 'rev':
        dR_x = inv_jac * d_x
        return dR_x

f = (omf.wrap(apply_nl)
        .add_output('x', resid='R_x')  # map state 'x' to its residual 'R_x'
        .declare_partials(of='*', wrt='*'))

p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f,
                                                  solve_nonlinear=solve_nonlinear,
                                                  solve_linear=solve_linear,
                                                  linearize=linearize))
p.setup()

p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()

J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])
print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
```

```
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]
```
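As a sanity check, the analytic entries coded in linearize can be compared against complex-step derivatives of apply_nl at the converged state, mirroring what declaring the partials with method='cs' would have produced. A standalone sketch (plain Python, no OpenMDAO required):

```python
def apply_nl(a, b, c, x):
    return a * x ** 2 + b * x + c

a, b, c = 2., -8., 6.
x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)  # x = 3 at the solution

# analytic partials of the residual, as coded in linearize above
analytic = {'a': x ** 2, 'b': x, 'c': 1.0, 'x': 2 * a * x + b}

# complex-step partials of the residual w.r.t. each argument
h = 1e-30
cs = {
    'a': apply_nl(a + 1j * h, b, c, x).imag / h,
    'b': apply_nl(a, b + 1j * h, c, x).imag / h,
    'c': apply_nl(a, b, c + 1j * h, x).imag / h,
    'x': apply_nl(a, b, c, x + 1j * h).imag / h,
}
for key in analytic:
    assert abs(analytic[key] - cs[key]) < 1e-12
```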