"""
Functions for making assertions about OpenMDAO Systems.
"""
from fnmatch import fnmatch
import warnings
import unittest
from contextlib import contextmanager
from functools import wraps
import numpy as np
try:
from jaxlib.xla_extension import ArrayImpl
except ImportError:
ArrayImpl = None
from openmdao.core.component import Component
from openmdao.core.group import Group
from openmdao.jacobians.dictionary_jacobian import DictionaryJacobian
from openmdao.utils.general_utils import pad_name
from openmdao.utils.om_warnings import reset_warning_registry
from openmdao.utils.mpi import MPI
from openmdao.utils.testing_utils import snum_equal
@contextmanager
def assert_warning(category, msg, contains_msg=False, ranks=None):
"""
Context manager asserting that a warning is issued.
Parameters
----------
category : class
The class of the expected warning.
msg : str
The text of the expected warning.
contains_msg : bool
Set to True to check that the warning text contains msg, rather than checking equality.
ranks : int or list of int, optional
The global ranks on which the warning is expected.
Yields
------
None
Raises
------
AssertionError
If the expected warning is not raised.
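Examples
--------
A minimal, illustrative sketch; the warning text here is arbitrary:
>>> with assert_warning(UserWarning, "some warning text"):
...     warnings.warn("some warning text", UserWarning)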
"""
with reset_warning_registry():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield
if ranks is not None:
if MPI is None:
raise RuntimeError("ranks argument has been specified but MPI is not active")
else:
if not isinstance(ranks, list):
ranks = [ranks]
if MPI.COMM_WORLD.rank not in ranks:
return
for warn in w:
if contains_msg:
warn_clause = msg in str(warn.message)
else:
warn_clause = str(warn.message) == msg
if (issubclass(warn.category, category) and warn_clause):
break
else:
msg = f"Did not see expected {category.__name__}:\n{msg}"
if w:
ws = '\n'.join([str(ww.message) for ww in w])
categories = '\n'.join([str(ww.category.__name__) for ww in w])
msg += f"\nDid see warnings [{categories}]:\n{ws}"
raise AssertionError(msg)
@contextmanager
def assert_warnings(expected_warnings):
"""
Context manager asserting that expected warnings are issued.
Parameters
----------
expected_warnings : iterable of (class, str)
The category and text of the expected warnings.
Yields
------
None
Raises
------
AssertionError
If any of the expected warnings is not raised.
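Examples
--------
An illustrative sketch checking two warnings in one block; the messages are arbitrary:
>>> with assert_warnings([(UserWarning, "first message"),
...                       (DeprecationWarning, "second message")]):
...     warnings.warn("first message", UserWarning)
...     warnings.warn("second message", DeprecationWarning)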
"""
with reset_warning_registry():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield
for category, msg in expected_warnings:
for warn in w:
if (issubclass(warn.category, category) and str(warn.message) == msg):
break
else:
raise AssertionError("Did not see expected %s: %s" % (category.__name__, msg))
@contextmanager
def assert_no_warning(category, msg=None, contains=False):
"""
Context manager asserting that a warning is not issued.
Parameters
----------
category : class
The class of the warning.
msg : str or None
The text of the warning. If None then only the warning class will be checked.
contains : bool
If True, check that the warning text contains msg, rather than checking equality.
Yields
------
None
Raises
------
AssertionError
If the warning is raised.
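Examples
--------
A minimal sketch; the body simply issues no DeprecationWarning:
>>> with assert_no_warning(DeprecationWarning):
...     result = 1 + 1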
"""
with reset_warning_registry():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield
for warn in w:
if issubclass(warn.category, category):
if msg is None:
raise AssertionError(f"Found warning: {category} {str(warn.message)}")
elif contains:
if msg in str(warn.message):
raise AssertionError(f"Found warning: {category} containing '{msg}'")
elif str(warn.message) == msg:
raise AssertionError(f"Found warning: {category} {msg}")
def assert_check_partials(data, atol=1e-6, rtol=1e-6):
"""
Raise an assertion if any entry returned by check_partials exceeds the given tolerance.
Parameters
----------
data : dict of dicts of dicts
First key:
is the component name;
Second key:
is the (output, input) tuple of strings;
Third key:
is one of ['rel error', 'abs error', 'magnitude', 'J_fd', 'J_fwd', 'J_rev'];
For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
forward - fd, adjoint - fd, forward - adjoint.
For 'J_fd', 'J_fwd', 'J_rev' the value is: A numpy array representing the computed
Jacobian for the three different methods of computation.
atol : float
Absolute error. Default is 1e-6.
rtol : float
Relative error. Default is 1e-6.
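Examples
--------
A hedged usage sketch; ``prob`` is a hypothetical, fully set up Problem:
>>> data = prob.check_partials(out_stream=None)  # doctest: +SKIP
>>> assert_check_partials(data, atol=1e-6, rtol=1e-6)  # doctest: +SKIP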
"""
error_string = ''
absrel_header = 'abs/rel'
wrt_header = '< output > wrt < variable >'
norm_value_header = 'norm value'
len_absrel_width = len(absrel_header)
norm_types = ['fwd-fd', 'rev-fd', 'fwd-rev']
len_norm_type_width = max(len(s) for s in norm_types)
for comp in data:
len_wrt_width = len(wrt_header)
len_norm_width = len(norm_value_header)
bad_derivs = []
inconsistent_derivs = set()
# Find all derivatives whose errors exceed tolerance.
# Also, size the output to precompute column extents.
for key, pair_data in data[comp].items():
var, wrt = key
for error_type, tolerance in [('abs error', atol), ('rel error', rtol), ]:
actuals = pair_data[error_type]
if not isinstance(actuals, list):
actuals = [actuals]
incon = pair_data.get('rank_inconsistent')
if incon:
inconsistent_derivs.add(key)
for actual in actuals:
for error_val, mode in zip(actual, norm_types):
in_error = False
if error_val is None:
# Reverse derivatives only computed on matrix free comps.
continue
if not np.isnan(error_val):
if not np.allclose(error_val, 0.0, atol=tolerance):
if error_type == 'rel error' and mode == 'fwd-fd' and \
np.allclose(pair_data['J_fwd'], 0.0, atol=atol) and \
np.allclose(pair_data['J_fd'], 0.0, atol=atol):
# Special case: both fd and fwd are really tiny, so we want to
# ignore the rather large relative errors.
in_error = False
else:
# This is a bona-fide error.
in_error = True
elif error_type == 'abs error' and mode == 'fwd-fd':
# Either analytic or approximated derivatives contain a NaN.
in_error = True
if in_error:
wrt_string = f'{var} wrt {wrt}'
norm_string = str(error_val)
bad_derivs.append((wrt_string, norm_string, error_type, mode))
len_wrt_width = max(len_wrt_width, len(wrt_string))
len_norm_width = max(len_norm_width, len(norm_string))
if bad_derivs or inconsistent_derivs:
comp_error_string = ''
if bad_derivs:
for wrt_string, norm_string, error_type, mode in bad_derivs:
err_msg = '{0} | {1} | {2} | {3}'.format(
pad_name(wrt_string, len_wrt_width),
pad_name(error_type.split()[0], len_absrel_width),
pad_name(mode, len_norm_type_width),
pad_name(norm_string, len_norm_width)) + '\n'
comp_error_string += err_msg
if inconsistent_derivs:
comp_error_string += (
"\nInconsistent derivs across processes for keys: "
f"{sorted(inconsistent_derivs)}.\nCheck that distributed outputs are properly "
"reduced when computing\nderivatives of serial inputs.")
name_header = 'Component: {}\n'.format(comp)
len_name_header = len(name_header)
header = len_name_header * '-' + '\n'
header += name_header
header += len_name_header * '-' + '\n'
header += '{0} | {1} | {2} | {3}'.format(
pad_name(wrt_header, len_wrt_width),
pad_name(absrel_header, len_absrel_width),
pad_name('norm', len_norm_type_width),
pad_name(norm_value_header, len_norm_width),
) + '\n'
header += '{0} | {1} | {2} | {3}'.format(
len_wrt_width * '-',
len_absrel_width * '-',
len_norm_type_width * '-',
len_norm_width * '-',
) + '\n'
comp_error_string = header + comp_error_string
error_string += comp_error_string
# if error string then raise error with that string
if error_string:
header_line1 = 'Assert Check Partials failed for the following Components'
header_line2 = f'with absolute tolerance = {atol} and relative tolerance = {rtol}'
header_width = max(len(header_line1), len(header_line2))
header = '\n' + header_width * '=' + '\n'
header += header_line1 + '\n'
header += header_line2 + '\n'
header += header_width * '=' + '\n'
error_string = header + error_string
raise ValueError(error_string)
def assert_check_totals(totals_data, atol=1e-6, rtol=1e-6):
"""
Raise an assertion if any entry returned by check_totals exceeds the given tolerance.
Parameters
----------
totals_data : Dict of Dicts of Tuples of Floats
First key:
is the (output, input) tuple of strings;
Second key:
is one of ['rel error', 'abs error', 'magnitude', 'fdstep'];
For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
forward - fd, adjoint - fd, forward - adjoint.
atol : float
Absolute error. Default is 1e-6.
rtol : float
Relative error. Default is 1e-6.
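Examples
--------
A hedged usage sketch; ``prob``, ``'obj'``, and ``'x'`` are hypothetical names:
>>> totals = prob.check_totals(of=['obj'], wrt=['x'], out_stream=None)  # doctest: +SKIP
>>> assert_check_totals(totals, atol=1e-6, rtol=1e-6)  # doctest: +SKIP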
"""
fails = []
incon_keys = set()
for key, dct in totals_data.items():
if 'inconsistent_keys' in dct:
incon_keys = dct['inconsistent_keys']
Jname = 'J_fwd' if 'J_fwd' in dct else 'J_rev'
try:
dct[Jname]
dct['J_fd']
except Exception as err:
raise err.__class__(f"For key {key}: {err}")
try:
for i in range(3):
erel, eabs = dct['rel error'][i], dct['abs error'][i]
if erel is not None and not np.isnan(erel):
if erel > rtol:
raise ValueError(f"rel tolerance of {erel} > allowed rel tolerance "
f"of {rtol}.")
if eabs is not None:
if eabs > atol:
raise ValueError(f"abs tolerance of {eabs} > allowed abs tolerance "
f"of {atol}.")
except ValueError as err:
fails.append((key, dct, err, Jname))
fail_list = []
if fails:
fail_list.extend(
[f"Totals differ for {key}:\nAnalytic:\n{dct[Jname]}\nFD:\n{dct['J_fd']}\n{err}"
for key, dct, err, Jname in fails])
if incon_keys:
ders = [f"{sof} wrt {swrt}" for sof, swrt in sorted(incon_keys)]
fail_list.append(f"During total derivative computation, the following partial derivatives "
"resulted in serial inputs that were inconsistent across processes: "
f"{ders}.")
if fails or incon_keys:
raise ValueError('\n\n'.join(fail_list))
def assert_no_approx_partials(system, include_self=True, recurse=True, method='any', excludes=None):
"""
Raise assertion error if any component within system is using approximated partials.
Parameters
----------
system : System
The system under which to search for approximated partials.
include_self : bool
If True, include this system in the iteration.
recurse : bool
If True, iterate over the whole tree under this system.
method : str
Specifically look for Components with this method of approx partials. Values can be
'cs', 'fd', or 'any'. 'any' means either 'cs' or 'fd'. The default is 'any'.
excludes : str, iter of str, or None
Glob patterns for pathnames to exclude from the check. Default is None, which
excludes nothing.
Raises
------
AssertionError
If a Component under the given system is found to be using approximated partials.
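Examples
--------
A hedged sketch; ``prob`` is a hypothetical Problem whose model should use only analytic partials:
>>> assert_no_approx_partials(prob.model, recurse=True, method='any')  # doctest: +SKIP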
"""
if isinstance(excludes, str):
excludes = [excludes, ]
has_approx_partials = False
msg = 'The following components use approximated partials:\n'
for s in system.system_iter(include_self=include_self, recurse=recurse):
if isinstance(s, Component):
if excludes is not None and any(fnmatch(s.pathname, exclude) for exclude in excludes):
continue
if s._approx_schemes:
if method == 'any' or method in s._approx_schemes:
has_approx_partials = True
approx_partials = [(k, v['method'])
for k, v in s._declared_partials_patterns.items()
if 'method' in v and v['method']]
msg += ' ' + s.pathname + '\n'
for key, meth in approx_partials:
key = (str(key[0]), str(key[1]))
msg += ' of={0:12s} wrt={1:12s} method={2:2s}\n'.format(key[0],
key[1],
meth)
if has_approx_partials:
raise AssertionError(msg)
def assert_no_dict_jacobians(system, include_self=True, recurse=True):
"""
Raise an assertion error if any Group within system is found to be using dictionary jacobians.
Parameters
----------
system : System
The system under which to search for dictionary jacobians.
include_self : bool
If True, include this system in the iteration.
recurse : bool
If True, iterate over the whole tree under this system.
Raises
------
AssertionError
If a Group under the given system is found to be using dictionary jacobians.
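Examples
--------
A hedged sketch; ``prob`` is a hypothetical Problem:
>>> assert_no_dict_jacobians(prob.model, recurse=True)  # doctest: +SKIP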
"""
parts = ['The following groups use dictionary jacobians:\n']
for s in system.system_iter(include_self=include_self, recurse=recurse, typ=Group):
if isinstance(s._jacobian, DictionaryJacobian):
parts.append(' ' + s.pathname)
if len(parts) > 1:
raise AssertionError('\n'.join(parts))
def assert_near_equal(actual, desired, tolerance=1e-15, tol_type='rel'):
"""
Check relative error.
Determine that the relative error between `actual` and `desired`
is within `tolerance`. If `desired` is zero, then use absolute error.
Can handle some data structures. Generates warnings for data types it cannot handle.
Parameters
----------
actual : float, array-like, dict
The value from the test.
desired : float, array-like, dict
The value expected.
tolerance : float
Maximum relative or absolute error.
For relative tolerance: ``(actual - desired) / desired``.
For absolute tolerance: ``(actual - desired)``.
tol_type : {'rel', 'abs'}
Type of error to use: 'rel' for relative error, 'abs' for absolute error.
Default is set to 'rel'.
Returns
-------
float
The error.
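Examples
--------
Simple illustrative cases with arbitrary values:
>>> err = assert_near_equal(1.0000001, 1.0, tolerance=1e-6)
>>> err = assert_near_equal(np.array([3.0, 4.0]), [3.0, 4.0], tolerance=1e-10)
>>> err = assert_near_equal({'x': 2.0}, {'x': 2.0}, tolerance=1e-12)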
"""
# Try to make similar things of the same type so they can be compared
# make arrays out of scalars
if type(actual) in [int, float, np.int64, np.float64, np.int32, np.complex128]:
actual = np.atleast_1d(actual)
if type(desired) in [int, float, np.int64, np.float64, np.int32, np.complex128]:
desired = np.atleast_1d(desired)
# Handle jax arrays, if available
if ArrayImpl is not None:
if isinstance(actual, ArrayImpl):
actual = np.atleast_1d(actual)
if isinstance(desired, ArrayImpl):
desired = np.atleast_1d(desired)
# if desired is numeric list or tuple, make ndarray out of it
if isinstance(actual, (list, tuple)):
actual = np.asarray(actual)
if isinstance(desired, (list, tuple)):
desired = np.asarray(desired)
# In case they are PromAbsDict and other dict-like objects
if isinstance(actual, dict) and type(actual) is not dict:
actual = dict(actual)
if isinstance(desired, dict) and type(desired) is not dict:
desired = dict(desired)
if type(actual) is not type(desired):
raise ValueError(f'actual {type(actual)}, desired {type(desired)} have different types')
if isinstance(actual, type) and isinstance(desired, type):
if actual != desired:
raise ValueError(
'actual type %s, and desired type %s are different' % (actual, desired))
return 0
# The code below can only handle these data types
_supported_types = [dict, set, str, bool, np.ndarray, type(None)]
if type(actual) not in _supported_types:
warnings.warn(
f"The function, assert_near_equal, does not support the actual value type: '"
f"{type(actual)}'.")
return 0
if type(desired) not in _supported_types:
warnings.warn(
f"The function, assert_near_equal, does not support the desired value type: '"
f"{type(actual)}'.")
return 0
if isinstance(actual, dict) and isinstance(desired, dict):
actual_keys = set(actual.keys())
desired_keys = set(desired.keys())
if actual_keys.symmetric_difference(desired_keys):
msg = 'Actual and desired keys differ. Actual extra keys: {}, Desired extra keys: {}'
actual_extra = actual_keys.difference(desired_keys)
desired_extra = desired_keys.difference(actual_keys)
raise KeyError(msg.format(actual_extra, desired_extra))
error = 0.
for key in actual_keys:
try:
new_error = assert_near_equal(actual[key], desired[key], tolerance, tol_type)
error = max(error, new_error)
except ValueError as exception:
msg = '{}: '.format(key) + str(exception)
raise ValueError(msg) from None
except KeyError as exception:
msg = '{}: '.format(key) + str(exception)
raise KeyError(msg) from None
elif isinstance(actual, set) and isinstance(desired, set):
if actual.symmetric_difference(desired):
actual_extra = actual.difference(desired)
desired_extra = desired.difference(actual)
raise KeyError("Actual and desired sets differ. "
f"Actual extra values: {actual_extra}, "
f"Desired extra values: {desired_extra}")
error = 0.
elif isinstance(actual, str) and isinstance(desired, str):
if actual != desired:
raise ValueError(
'actual %s, desired %s strings have different values' % (actual, desired))
error = 0.0
elif isinstance(actual, bool) and isinstance(desired, bool):
if actual != desired:
raise ValueError(
'actual %s, desired %s booleans have different values' % (actual, desired))
error = 0.0
elif actual is None and desired is None:
error = 0.0
# array values
elif isinstance(actual, np.ndarray) and isinstance(desired, np.ndarray):
if actual.dtype == object or desired.dtype == object:
if actual.dtype == object:
warnings.warn(
f"The function, assert_near_equal, does not support the actual value ndarray "
f"type of: '"
f"{type(actual.dtype)}'.")
if desired.dtype == object:
warnings.warn(
f"The function, assert_near_equal, does not support the desired value ndarray "
f"type of: '"
f"{type(desired.dtype)}'.")
error = 0.0
else:
actual = np.atleast_1d(actual)
desired = np.atleast_1d(desired)
if actual.shape != desired.shape:
raise ValueError(
'actual and desired have differing shapes.'
' actual {}, desired {}'.format(actual.shape, desired.shape))
# check that NaN values, if any, occur at the same locations in both arrays
if not np.all(np.isnan(actual) == np.isnan(desired)):
if actual.size == 1 and desired.size == 1:
raise ValueError('actual %s, desired %s' % (actual, desired))
else:
raise ValueError('actual and desired values have non-matching nan'
' values')
if np.linalg.norm(desired) == 0 or tol_type == 'abs':
error = np.linalg.norm(actual - desired)
else:
error = np.linalg.norm(actual - desired) / np.linalg.norm(desired)
if abs(error) > tolerance:
if actual.size < 10 and desired.size < 10:
raise ValueError('actual %s, desired %s, %s error %s, tolerance %s'
% (actual, desired, tol_type, error, tolerance))
else:
raise ValueError('arrays do not match, rel error %.3e > tol (%.3e)' %
(error, tolerance))
elif isinstance(actual, tuple) and isinstance(desired, tuple):
error = 0.0
for act, des in zip(actual, desired):
new_error = assert_near_equal(act, des, tolerance, tol_type)
error = max(error, new_error)
else:
raise ValueError(
'actual and desired have unexpected types: %s, %s' % (type(actual), type(desired)))
return error
def assert_equal_arrays(a1, a2):
"""
Check that two arrays are equal.
This is a simplified method useful when the arrays to be compared may
not be numeric. It simply compares the shapes of the two arrays and then
does a value by value comparison.
Parameters
----------
a1 : array
The first array to compare.
a2 : array
The second array to compare.
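Examples
--------
A small sketch with a non-numeric array:
>>> assert_equal_arrays(np.array(['a', 'b', 'c']), np.array(['a', 'b', 'c']))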
"""
assert a1.shape == a2.shape
for x, y in zip(a1.flat, a2.flat):
assert x == y
def assert_equal_numstrings(s1, s2, atol=1e-6, rtol=1e-6):
"""
Check that two strings containing numbers are equal after converting numerical parts to floats.
Parameters
----------
s1 : str
The first numeric string to compare.
s2 : str
The second numeric string to compare.
atol : float
Absolute error tolerance. Default is 1e-6.
rtol : float
Relative error tolerance. Default is 1e-6.
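Examples
--------
An illustrative sketch, assuming ``snum_equal`` compares the numeric substrings as floats:
>>> assert_equal_numstrings('x = 1.5000001, y = 2.0', 'x = 1.5, y = 2.0', atol=1e-3)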
"""
assert snum_equal(s1, s2, atol=atol, rtol=rtol)
def skip_helper(msg):
"""
Raise a SkipTest.
Parameters
----------
msg : str
The skip message.
Raises
------
SkipTest
"""
raise unittest.SkipTest(msg)
class SkipParameterized(object):
"""
Replaces the parameterized class, skipping decorated test cases.
"""
@classmethod
def expand(cls, input, name_func=None, doc_func=None, skip_on_empty=False, **legacy):
"""
Decorate a test so that it raises a SkipTest.
Parameters
----------
input : iterable
Not used (part of parameterized API).
name_func : function
Not used (part of parameterized API).
doc_func : function
Not used (part of parameterized API).
skip_on_empty : bool
Not used (part of parameterized API).
**legacy : dict
Not used (part of parameterized API).
Returns
-------
function
The wrapper function.
"""
skip_msg = "requires 'parameterized' (install openmdao[test])"
def parameterized_expand_wrapper(f, instance=None):
"""
Wrap a function so that it raises a SkipTest.
Parameters
----------
f : function
Function to be wrapped.
instance : None
Not used (part of parameterized API).
Returns
-------
function
The wrapped function.
"""
return wraps(f)(lambda f: skip_helper(skip_msg))
return parameterized_expand_wrapper