Unit Testing Partial Derivatives
If you want to check the implementation of a Component's partial derivatives as part of a unit test, you can use a custom assert function, assert_check_partials.
openmdao.utils.assert_utils.assert_check_partials(data, atol=1e-06, rtol=1e-06, verbose=False, max_display_shape=(20, 20))
Raise an assertion if any entry in the data returned by check_partials is above a tolerance.
- Parameters:
  - data : dict of dicts of dicts
    - First key: the component name.
    - Second key: the (output, input) tuple of strings.
    - Third key: one of ['tol violation', 'magnitude', 'J_fd', 'J_fwd', 'J_rev', 'vals_at_max_error', 'directional_fd_fwd', 'directional_fd_rev', 'directional_fwd_rev', 'rank_inconsistent', 'matrix_free', 'directional', 'steps'].
    For 'J_fd', 'J_fwd', and 'J_rev', the value is a numpy array representing the Jacobian computed by each of the three methods. For 'tol violation' and 'vals_at_max_error', the value is a tuple containing values for forward - fd, reverse - fd, and forward - reverse. For 'magnitude', the value is a tuple giving the maximum magnitude of the values found in J_fwd, J_rev, and J_fd.
  - atol : float
    Absolute error. Default is 1e-6.
  - rtol : float
    Relative error. Default is 1e-6.
  - verbose : bool
    When True, display more Jacobian information.
  - max_display_shape : tuple of int
    Maximum shape of the Jacobians to display directly in the error message. Default is (20, 20). Only active if verbose is True.
In your unit test, after calling check_partials on a Component, you can pass the value it returns to the assert_check_partials function.
Usage
In the following code, compute_partials is intentionally coded incorrectly to demonstrate how assert_check_partials can be used to detect this kind of error.
import numpy as np

import openmdao.api as om
from openmdao.utils.assert_utils import assert_check_partials


class BrokenDerivComp(om.ExplicitComponent):

    def setup(self):
        self.add_input('x1', 3.0)
        self.add_input('x2', 5.0)

        self.add_output('y', 5.5)

    def setup_partials(self):
        self.declare_partials(of='*', wrt='*')

    def compute(self, inputs, outputs):
        """ Compute outputs. """
        outputs['y'] = 3.0 * inputs['x1'] + 4.0 * inputs['x2']

    def compute_partials(self, inputs, partials):
        """Intentionally incorrect derivative."""
        J = partials
        J['y', 'x1'] = np.array([4.0])
        J['y', 'x2'] = np.array([40])


prob = om.Problem()
prob.model.add_subsystem('comp', BrokenDerivComp())

prob.set_solver_print(level=0)

prob.setup()
prob.run_model()

data = prob.check_partials(out_stream=None)
print(data)

try:
    assert_check_partials(data, atol=1.e-6, rtol=1.e-6)
except ValueError as err:
    print(str(err))
{'comp': {('y', 'x1'): {'J_fwd': array([[4.]]), 'J_fd': array([[3.]]), 'rows': None, 'cols': None, 'tol violation': _ErrorData(forward=0.999997000468845, reverse=None, fwd_rev=None), 'magnitude': _MagnitudeData(forward=4.0, reverse=0.0, fd=2.9999999995311555), 'vals_at_max_error': _ErrorData(forward=(4.0, 2.9999999995311555), reverse=None, fwd_rev=None), 'abs error': _ErrorData(forward=1.0000000004688445, reverse=None, fwd_rev=None), 'rel error': _ErrorData(forward=0.3333333335417087, reverse=None, fwd_rev=None)}, ('y', 'x2'): {'J_fwd': array([[40.]]), 'J_fd': array([[4.]]), 'rows': None, 'cols': None, 'tol violation': _ErrorData(forward=35.999995999440884, reverse=None, fwd_rev=None), 'magnitude': _MagnitudeData(forward=40.0, reverse=0.0, fd=4.000000000559112), 'vals_at_max_error': _ErrorData(forward=(40.0, 4.000000000559112), reverse=None, fwd_rev=None), 'abs error': _ErrorData(forward=35.99999999944089, reverse=None, fwd_rev=None), 'rel error': _ErrorData(forward=8.99999999860222, reverse=None, fwd_rev=None)}}}
==============================================================
assert_check_partials failed for the following Components
with absolute tolerance = 1e-06 and relative tolerance = 1e-06
==============================================================
---------------
Component: comp
---------------
< output > wrt < variable > | max abs/rel | diff | value
---------------------------------------------------------------
y wrt x1 | abs | fd-fwd | 1.0
y wrt x1 | rel | fd-fwd | 0.33333333
y wrt x2 | abs | fd-fwd | 36.0
y wrt x2 | rel | fd-fwd | 9.0
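The abs and rel rows in the report above can be reproduced by hand from the analytic and finite-difference values printed in the data. The sketch below assumes that the relative error is the absolute error divided by the magnitude of the finite-difference value, which is consistent with the numbers shown. The helper name abs_rel_error is hypothetical and is not part of OpenMDAO.

```python
def abs_rel_error(analytic, fd):
    """Return the (absolute, relative) difference, using fd as the reference."""
    abs_err = abs(analytic - fd)
    rel_err = abs_err / abs(fd)
    return abs_err, rel_err


# (analytic, fd) pairs copied from the 'vals_at_max_error' entries printed above.
for wrt, analytic, fd in [('x1', 4.0, 2.9999999995311555),
                          ('x2', 40.0, 4.000000000559112)]:
    abs_err, rel_err = abs_rel_error(analytic, fd)
    print(f"y wrt {wrt}: abs={abs_err:.8g}, rel={rel_err:.8g}")
```

Both differences are far larger than the 1e-6 tolerances passed to assert_check_partials, which is why both partials appear in the failure report.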