ImplicitFuncComp#
`ImplicitFuncComp` is a component that provides a shortcut for building an ImplicitComponent based on a Python function. That function takes inputs and states as arguments and returns residual values. The function may take some inputs that are non-differentiable and are assumed to be static during the computation of derivatives. These static values may be of any hashable type. All other arguments and return values must be either floats or numpy arrays. The mapping between a state argument and its residual output must be specified in the metadata when the output (state) is added, by setting `resid` to the name of the residual.
It may seem confusing to use `add_output` to specify state variables, since the state variables are actually input arguments to the function, but in OpenMDAO's view of the world, states are outputs, so we use `add_output` to specify them. Also, using the metadata to specify which input arguments are actually states gives more flexibility in how the function arguments are ordered. For example, if it's desirable for a function to be passable to `scipy.optimize.newton`, then the function's arguments can be ordered with the states first, followed by the inputs, to match the order expected by `scipy.optimize.newton`.
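As a standalone sketch of that ordering idea (outside of OpenMDAO; the quadratic residual matches the examples below), a states-first residual function can be handed directly to `scipy.optimize.newton`:

```python
from scipy.optimize import newton

# residual written with the state x first, then the inputs a, b, c,
# matching the (x, *args) calling convention of scipy.optimize.newton
def apply_nl(x, a, b, c):
    return a * x ** 2 + b * x + c

# solve R(x) = 0 for a=2, b=-8, c=6 (roots at x=1 and x=3)
root = newton(apply_nl, x0=0.0, args=(2.0, -8.0, 6.0))
print(root)
```

The same function could then be wrapped for an `ImplicitFuncComp`, with the state identified via the output metadata rather than by its position in the argument list.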
The `add_output` function is part of the Function Metadata API. You use this API to specify the various metadata that OpenMDAO needs in order to configure a fully functional implicit component. You should read and understand the Function Metadata API before continuing with this section.
ImplicitFuncComp Options#
| Option | Default | Acceptable Values | Acceptable Types | Description |
| --- | --- | --- | --- | --- |
| always_opt | False | [True, False] | ['bool'] | If True, force nonlinear operations on this component to be included in the optimization loop even if this component is not relevant to the design variables and responses. |
| assembled_jac_type | csc | ['csc', 'dense'] | N/A | Linear solver(s) in this group or implicit component, if using an assembled jacobian, will use this type. |
| derivs_method | N/A | ['jax', 'cs', 'fd', None] | N/A | The method to use for computing derivatives. |
| distributed | False | [True, False] | ['bool'] | If True, set all variables in this component as distributed across multiple processes. |
| run_root_only | False | [True, False] | ['bool'] | If True, call compute, compute_partials, linearize, apply_linear, apply_nonlinear, and compute_jacvec_product only on rank 0 and broadcast the results to the other ranks. |
| use_jit | True | [True, False] | ['bool'] | If True, attempt to use jit on compute_primal, assuming jax or some other AD package is active. |
ImplicitFuncComp Constructor#
The call signature for the `ImplicitFuncComp` constructor is:

```python
ImplicitFuncComp.__init__(apply_nonlinear, solve_nonlinear=None, linearize=None, solve_linear=None, **kwargs)
```

Initialize attributes.
ImplicitFuncComp Example: A simple implicit function component#
The simplest implicit function component requires the definition of a function that takes inputs and states as arguments and returns residual values. This function maps to the `apply_nonlinear` method in the OpenMDAO component API. Here's an example:
```python
import openmdao.api as om
import openmdao.func_api as omf

def apply_nl(a, b, c, x):  # inputs a, b, c and state x
    R_x = a * x ** 2 + b * x + c
    return R_x

f = (omf.wrap(apply_nl)
     .add_output('x', resid='R_x', val=0.0)
     .declare_partials(of='*', wrt='*', method='cs')
     )

p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f))
p.model.nonlinear_solver = om.NewtonSolver(solve_subsystems=False, iprint=0)

# needed since comp is implicit and doesn't define a solve_linear
p.model.linear_solver = om.DirectSolver()

p.setup()
p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()
```
ImplicitFuncComp Example: Partial Derivatives#
All nonzero partial derivatives for an `ImplicitFuncComp` must be declared when the function is wrapped; otherwise, OpenMDAO will assume that all partial derivatives for that component are zero. In this example we compute total derivatives, so the partials must be declared. Because our implicit function component does not define its own `linearize` function, we specify a `method` of `fd` or `cs` when we declare the partials, so they'll be computed using finite difference or complex step. Finally, because our implicit component doesn't define its own `solve_linear` function, we have to specify a linear solver for the component.
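As background, the complex-step (`cs`) method approximates a derivative as Im(R(x + ih))/h, which avoids the subtractive cancellation of finite differencing. A standalone sketch using the same quadratic residual as the examples here:

```python
def apply_nl(a, b, c, x):
    return a * x ** 2 + b * x + c

# complex-step derivative of the residual with respect to the state x:
# dR/dx ~= Im(R(x + i*h)) / h, accurate to near machine precision for tiny h
h = 1e-30
x = 3.0
dR_dx = apply_nl(2.0, -8.0, 6.0, x + 1j * h).imag / h

print(dR_dx)  # analytic value is 2*a*x + b = 4.0
```

This is why `cs` requires the function to be complex-safe; OpenMDAO performs an equivalent perturbation internally when `method='cs'` is declared.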
```python
def apply_nl(a, b, c, x):
    R_x = a * x ** 2 + b * x + c
    return R_x

def solve_nonlinear(a, b, c, x):
    x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)
    return x

f = (omf.wrap(apply_nl)
     .add_output('x', resid='R_x', val=0.0)
     .declare_partials(of='*', wrt='*', method='cs')
     )

p = om.Problem()
comp = p.model.add_subsystem('comp', om.ImplicitFuncComp(f, solve_nonlinear=solve_nonlinear))

# needed since comp is implicit and doesn't define a solve_linear
comp.linear_solver = om.DirectSolver()

p.setup()
p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()

J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])

print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
```

```
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]
```
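The totals printed above can be checked by hand with the implicit function theorem: at the solved state, dx/du = -(∂R/∂u)/(∂R/∂x) for each input u. A quick standalone check using the same residual:

```python
a, b, c = 2.0, -8.0, 6.0

# same root branch as solve_nonlinear above
x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)  # x = 3.0

# implicit function theorem: dx/du = -(dR/du) / (dR/dx)
dR_dx = 2 * a * x + b            # 4.0
dx_da = -(x ** 2) / dR_dx        # -2.25
dx_db = -x / dR_dx               # -0.75
dx_dc = -1.0 / dR_dx             # -0.25

print(dx_da, dx_db, dx_dc)
```

These match the `compute_totals` results, confirming that OpenMDAO is solving the linear system implied by the declared partials.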
ImplicitFuncComp Example: Specifying linearize, solve_linear, and solve_nonlinear functions#
The following implicit function component specifies `linearize`, `solve_linear`, and `solve_nonlinear` functions, so no external linear or nonlinear solvers are required to compute outputs or derivatives.
```python
def apply_nl(a, b, c, x):
    R_x = a * x ** 2 + b * x + c
    return R_x

def solve_nonlinear(a, b, c, x):
    x = (-b + (b ** 2 - 4 * a * c) ** 0.5) / (2 * a)
    return x

def linearize(a, b, c, x, partials):
    partials['x', 'a'] = x ** 2
    partials['x', 'b'] = x
    partials['x', 'c'] = 1.0
    partials['x', 'x'] = 2 * a * x + b

    # return any state to be passed on to solve_linear
    inv_jac = 1.0 / (2 * a * x + b)
    return inv_jac

def solve_linear(d_x, mode, inv_jac):
    if mode == 'fwd':
        d_x = inv_jac * d_x
        return d_x
    elif mode == 'rev':
        dR_x = inv_jac * d_x
        return dR_x

f = (omf.wrap(apply_nl)
     .add_output('x', resid='R_x', val=0.0)
     .declare_partials(of='*', wrt='*')
     )

p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f,
                                                  solve_nonlinear=solve_nonlinear,
                                                  solve_linear=solve_linear,
                                                  linearize=linearize))
p.setup()
p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()

J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])

print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
```

```
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]
```
Using jax to Compute Partial Derivatives#
If the function used to instantiate the ImplicitFuncComp declares partials or coloring that use `method='jax'`, or if the component's `derivs_method` option is set to `'jax'`, then the jax AD package will be used to compute all of the component's derivatives. Currently it's not possible to mix jax with finite difference methods (`'cs'` and `'fd'`) in the same component.
Note that `jax` is an optional OpenMDAO dependency, so you may need to install it manually:

```
pip install jax
pip install jaxlib
```

should work in most cases.
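As a standalone illustration of what jax provides (shown here outside of any OpenMDAO component), `jax.grad` can differentiate the same residual function directly with respect to any argument:

```python
import jax

def apply_nl(a, b, c, x):
    return a * x ** 2 + b * x + c

# differentiate the residual with respect to the state (argument index 3)
dR_dx = jax.grad(apply_nl, argnums=3)(2.0, -8.0, 6.0, 3.0)  # 2*a*x + b = 4
# ...and with respect to an input (argument index 0)
dR_da = jax.grad(apply_nl, argnums=0)(2.0, -8.0, 6.0, 3.0)  # x**2 = 9

print(float(dR_dx), float(dR_da))
```

When `derivs_method='jax'` is used, OpenMDAO obtains the component's partials through jax tracing like this rather than by perturbing inputs numerically.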
To activate jax's just-in-time compilation capability, set the `use_jit` option on the component. For example:
```python
p = om.Problem()
p.model.add_subsystem('comp', om.ImplicitFuncComp(f,
                                                  solve_nonlinear=solve_nonlinear,
                                                  solve_linear=solve_linear,
                                                  linearize=linearize,
                                                  use_jit=True))
p.setup()
p.set_val('comp.a', 2.)
p.set_val('comp.b', -8.)
p.set_val('comp.c', 6.)
p.run_model()

J = p.compute_totals(of=['comp.x'], wrt=['comp.a', 'comp.b', 'comp.c'])

print(('comp.x', 'comp.a'), J['comp.x', 'comp.a'])
print(('comp.x', 'comp.b'), J['comp.x', 'comp.b'])
print(('comp.x', 'comp.c'), J['comp.x', 'comp.c'])
```

```
('comp.x', 'comp.a') [[-2.25]]
('comp.x', 'comp.b') [[-0.75]]
('comp.x', 'comp.c') [[-0.25]]
```