"""
Functions for making assertions about OpenMDAO Systems.
"""
from math import isnan
from fnmatch import fnmatch
import warnings
import unittest
from contextlib import contextmanager
from functools import wraps
import numpy as np
try:
from jaxlib.xla_extension import ArrayImpl
except ImportError:
ArrayImpl = None
from openmdao.core.component import Component
from openmdao.core.group import Group
from openmdao.jacobians.dictionary_jacobian import DictionaryJacobian
from openmdao.utils.general_utils import pad_name
from openmdao.utils.om_warnings import reset_warning_registry
from openmdao.utils.mpi import MPI
from openmdao.utils.testing_utils import snum_equal
[docs]@contextmanager
def assert_warning(category, msg, contains_msg=False, ranks=None):
"""
Context manager asserting that a warning is issued.
Parameters
----------
category : class
The class of the expected warning.
msg : str
The text of the expected warning.
contains_msg : bool
Set to True to check that the warning text contains msg, rather than checking equality.
ranks : int or list of int, optional
The global ranks on which the warning is expected.
Yields
------
None
Raises
------
AssertionError
If the expected warning is not raised.
"""
with reset_warning_registry():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield
if ranks is not None:
if MPI is None:
raise RuntimeError("ranks argument has been specified but MPI is not active")
else:
if not isinstance(ranks, list):
ranks = [ranks]
if MPI.COMM_WORLD.rank not in ranks:
return
for warn in w:
if contains_msg:
warn_clause = msg in str(warn.message)
else:
warn_clause = str(warn.message) == msg
if (issubclass(warn.category, category) and warn_clause):
break
else:
msg = f"Did not see expected {category.__name__}:\n{msg}"
if w:
ws = '\n'.join([str(ww.message) for ww in w])
categories = '\n'.join([str(ww.category.__name__) for ww in w])
msg += f"\nDid see warnings [{categories}]:\n{ws}"
raise AssertionError(msg)
[docs]@contextmanager
def assert_warnings(expected_warnings):
"""
Context manager asserting that expected warnings are issued.
Parameters
----------
expected_warnings : iterable of (class, str)
The category and text of the expected warnings.
Yields
------
None
Raises
------
AssertionError
If all the expected warnings are not raised.
"""
with reset_warning_registry():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield
for category, msg in expected_warnings:
for warn in w:
if (issubclass(warn.category, category) and str(warn.message) == msg):
break
else:
raise AssertionError("Did not see expected %s: %s" % (category.__name__, msg))
[docs]@contextmanager
def assert_no_warning(category, msg=None):
"""
Context manager asserting that a warning is not issued.
Parameters
----------
category : class
The class of the warning.
msg : str or None
The text of the warning. If None then only the warning class will be checked.
Yields
------
None
Raises
------
AssertionError
If the warning is raised.
"""
with reset_warning_registry():
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
yield
for warn in w:
if issubclass(warn.category, category):
if msg is None:
raise AssertionError(f"Found warning: {category} {str(warn.message)}")
elif str(warn.message) == msg:
raise AssertionError(f"Found warning: {category} {msg}")
[docs]def assert_check_partials(data, atol=1e-6, rtol=1e-6):
"""
Raise assertion if any entry from the return from check_partials is above a tolerance.
Parameters
----------
data : dict of dicts of dicts
First key:
is the component name;
Second key:
is the (output, input) tuple of strings;
Third key:
is one of ['rel error', 'abs error', 'magnitude', 'J_fd', 'J_fwd', 'J_rev'];
For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
forward - fd, adjoint - fd, forward - adjoint.
For 'J_fd', 'J_fwd', 'J_rev' the value is: A numpy array representing the computed
Jacobian for the three different methods of computation.
atol : float
Absolute error. Default is 1e-6.
rtol : float
Relative error. Default is 1e-6.
"""
error_string = ''
absrel_header = 'abs/rel'
wrt_header = '< output > wrt < variable >'
norm_value_header = 'norm value'
len_absrel_width = len(absrel_header)
norm_types = ['fwd-fd', 'rev-fd', 'fd-rev']
len_norm_type_width = max(len(s) for s in norm_types)
for comp in data:
len_wrt_width = len(wrt_header)
len_norm_width = len(norm_value_header)
bad_derivs = []
inconsistent_derivs = set()
# Find all derivatives whose errors exceed tolerance.
# Also, size the output to precompute column extents.
for key, pair_data in data[comp].items():
var, wrt = key
for error_type, tolerance in [('abs error', atol), ('rel error', rtol), ]:
actuals = pair_data[error_type]
if not isinstance(actuals, list):
actuals = [actuals]
incon = pair_data.get('rank_inconsistent')
if incon:
inconsistent_derivs.add(key)
for actual in actuals:
for error_val, mode in zip(actual, norm_types):
in_error = False
if error_val is None:
# Reverse derivatives only computed on matrix free comps.
continue
if not np.isnan(error_val):
if not np.allclose(error_val, 0.0, atol=tolerance):
if error_type == 'rel error' and mode == 'fwd-fd' and \
np.allclose(pair_data['J_fwd'], 0.0, atol=atol) and \
np.allclose(pair_data['J_fd'], 0.0, atol=atol):
# Special case: both fd and fwd are really tiny, so we want to
# ignore the rather large relative errors.
in_error = False
else:
# This is a bona-fide error.
in_error = True
elif error_type == 'abs error' and mode == 'fwd-fd':
# Either analytic or approximated derivatives contain a NaN.
in_error = True
if in_error:
wrt_string = f'{var} wrt {wrt}'
norm_string = str(error_val)
bad_derivs.append((wrt_string, norm_string, error_type, mode))
len_wrt_width = max(len_wrt_width, len(wrt_string))
len_norm_width = max(len_norm_width, len(norm_string))
if bad_derivs or inconsistent_derivs:
comp_error_string = ''
if bad_derivs:
for wrt_string, norm_string, error_type, mode in bad_derivs:
err_msg = '{0} | {1} | {2} | {3}'.format(
pad_name(wrt_string, len_wrt_width),
pad_name(error_type.split()[0], len_absrel_width),
pad_name(mode, len_norm_type_width),
pad_name(norm_string, len_norm_width)) + '\n'
comp_error_string += err_msg
if inconsistent_derivs:
comp_error_string += (
"\nInconsistent derivs across processes for keys: "
f"{sorted(inconsistent_derivs)}.\nCheck that distributed outputs are properly "
"reduced when computing\nderivatives of serial inputs.")
name_header = 'Component: {}\n'.format(comp)
len_name_header = len(name_header)
header = len_name_header * '-' + '\n'
header += name_header
header += len_name_header * '-' + '\n'
header += '{0} | {1} | {2} | {3}'.format(
pad_name(wrt_header, len_wrt_width),
pad_name(absrel_header, len_absrel_width),
pad_name('norm', len_norm_type_width),
pad_name(norm_value_header, len_norm_width),
) + '\n'
header += '{0} | {1} | {2} | {3}'.format(
len_wrt_width * '-',
len_absrel_width * '-',
len_norm_type_width * '-',
len_norm_width * '-',
) + '\n'
comp_error_string = header + comp_error_string
error_string += comp_error_string
# if error string then raise error with that string
if error_string:
header_line1 = 'Assert Check Partials failed for the following Components'
header_line2 = f'with absolute tolerance = {atol} and relative tolerance = {rtol}'
header_width = max(len(header_line1), len(header_line2))
header = '\n' + header_width * '=' + '\n'
header += header_line1 + '\n'
header += header_line2 + '\n'
header += header_width * '=' + '\n'
error_string = header + error_string
raise ValueError(error_string)
[docs]def assert_check_totals(totals_data, atol=1e-6, rtol=1e-6):
"""
Raise assertion if any entry from the return from check_totals is above a tolerance.
Parameters
----------
totals_data : Dict of Dicts of Tuples of Floats
First key:
is the (output, input) tuple of strings;
Second key:
is one of ['rel error', 'abs error', 'magnitude', 'fdstep'];
For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
forward - fd, adjoint - fd, forward - adjoint.
atol : float
Absolute error. Default is 1e-6.
rtol : float
Relative error. Default is 1e-6.
"""
fails = []
incon_keys = set()
for key, dct in totals_data.items():
if 'inconsistent_keys' in dct:
incon_keys = dct['inconsistent_keys']
Jname = 'J_fwd' if 'J_fwd' in dct else 'J_rev'
try:
dct[Jname]
dct['J_fd']
except Exception as err:
raise err.__class__(f"For key {key}: {err}")
try:
for i in range(3):
erel, eabs = dct['rel error'][i], dct['abs error'][i]
if erel is not None and not np.isnan(erel):
if erel > rtol:
raise ValueError(f"rel tolerance of {erel} > allowed rel tolerance "
f"of {rtol}.")
if eabs is not None:
if eabs > atol:
raise ValueError(f"abs tolerance of {eabs} > allowed abs tolerance "
f"of {atol}.")
except ValueError as err:
fails.append((key, dct, err, Jname))
fail_list = []
if fails:
fail_list.extend(
[f"Totals differ for {key}:\nAnalytic:\n{dct[Jname]}\nFD:\n{dct['J_fd']}\n{err}"
for key, dct, err, Jname in fails])
if incon_keys:
ders = [f"{sof} wrt {swrt}" for sof, swrt in sorted(incon_keys)]
fail_list.append(f"During total derivative computation, the following partial derivatives "
"resulted in serial inputs that were inconsistent across processes: "
f"{ders}.")
if fails or incon_keys:
raise ValueError('\n\n'.join(fail_list))
[docs]def assert_no_approx_partials(system, include_self=True, recurse=True, method='any', excludes=None):
"""
Raise assertion error if any component within system is using approximated partials.
Parameters
----------
system : System
The system under which to search for approximated partials.
include_self : bool
If True, include this system in the iteration.
recurse : bool
If True, iterate over the whole tree under this system.
method : str
Specifically look for Components with this method of approx partials. Values can be
'cs', 'fd', or 'any'. 'any' means either 'cs' or 'fd'. The default is 'any'.
excludes : str, iter of str, or None
Glob patterns for pathnames to exclude from the check. Default is None, which
excludes nothing.
Raises
------
AssertionError
If a subsystem of group is found to be using approximated partials.
"""
if isinstance(excludes, str):
excludes = [excludes, ]
has_approx_partials = False
msg = 'The following components use approximated partials:\n'
for s in system.system_iter(include_self=include_self, recurse=recurse):
if isinstance(s, Component):
if excludes is not None and any(fnmatch(s.pathname, exclude) for exclude in excludes):
continue
if s._approx_schemes:
if method == 'any' or method in s._approx_schemes:
has_approx_partials = True
approx_partials = [(k, v['method'])
for k, v in s._declared_partials_patterns.items()
if 'method' in v and v['method']]
msg += ' ' + s.pathname + '\n'
for key, method in approx_partials:
key = (str(key[0]), str(key[1]))
msg += ' of={0:12s} wrt={1:12s} method={2:2s}\n'.format(key[0],
key[1],
method)
if has_approx_partials:
raise AssertionError(msg)
[docs]def assert_no_dict_jacobians(system, include_self=True, recurse=True):
"""
Raise an assertion error if any Group within system is found to be using dictionary jacobians.
Parameters
----------
system : System
The system under which to search for approximated partials.
include_self : bool
If True, include this system in the iteration.
recurse : bool
If True, iterate over the whole tree under this system.
Raises
------
AssertionError
If a subsystem of group is found to be using approximated partials.
"""
parts = ['The following groups use dictionary jacobians:\n']
for s in system.system_iter(include_self=include_self, recurse=recurse, typ=Group):
if isinstance(s._jacobian, DictionaryJacobian):
parts.append(' ' + s.pathname)
if len(parts) > 1:
raise AssertionError('\n'.join(parts))
[docs]def assert_near_equal(actual, desired, tolerance=1e-15):
"""
Check relative error.
Determine that the relative error between `actual` and `desired`
is within `tolerance`. If `desired` is zero, then use absolute error.
Can handle some data structures. Generates warnings for data types it cannot handle.
Parameters
----------
actual : float, array-like, dict
The value from the test.
desired : float, array-like, dict
The value expected.
tolerance : float
Maximum relative error ``(actual - desired) / desired``.
Returns
-------
float
The error.
"""
# Try to make similar things of the same type so they can be compared
# make arrays out of scalars
if type(actual) in [int, float, np.int64, np.float64, np.int32, np.complex128]:
actual = np.atleast_1d(actual)
if type(desired) in [int, float, np.int64, np.float64, np.int32, np.complex128]:
desired = np.atleast_1d(desired)
# Handle jax arrays, if available
if ArrayImpl is not None:
if isinstance(actual, ArrayImpl):
actual = np.atleast_1d(actual)
if isinstance(desired, ArrayImpl):
desired = np.atleast_1d(desired)
# if desired is numeric list or tuple, make ndarray out of it
if isinstance(actual, (list, tuple)):
actual = np.asarray(actual)
if isinstance(desired, (list, tuple)):
desired = np.asarray(desired)
# In case they are PromAbsDict and other dict-like objects
if isinstance(actual, dict) and type(actual) is not dict:
actual = dict(actual)
if isinstance(desired, dict) and type(desired) is not dict:
desired = dict(desired)
if type(actual) is not type(desired):
raise ValueError(f'actual {type(actual)}, desired {type(desired)} have different types')
if isinstance(actual, type) and isinstance(desired, type):
if actual != desired:
raise ValueError(
'actual type %s, and desired type %s are different' % (actual, desired))
return 0
# The code below can only handle these data types
_supported_types = [dict, set, str, bool, np.ndarray, type(None)]
if type(actual) not in _supported_types:
warnings.warn(
f"The function, assert_near_equal, does not support the actual value type: '"
f"{type(actual)}'.")
return 0
if type(desired) not in _supported_types:
warnings.warn(
f"The function, assert_near_equal, does not support the desired value type: '"
f"{type(actual)}'.")
return 0
if isinstance(actual, dict) and isinstance(desired, dict):
actual_keys = set(actual.keys())
desired_keys = set(desired.keys())
if actual_keys.symmetric_difference(desired_keys):
msg = 'Actual and desired keys differ. Actual extra keys: {}, Desired extra keys: {}'
actual_extra = actual_keys.difference(desired_keys)
desired_extra = desired_keys.difference(actual_keys)
raise KeyError(msg.format(actual_extra, desired_extra))
error = 0.
for key in actual_keys:
try:
new_error = assert_near_equal(actual[key], desired[key], tolerance)
error = max(error, new_error)
except ValueError as exception:
msg = '{}: '.format(key) + str(exception)
raise ValueError(msg) from None
except KeyError as exception:
msg = '{}: '.format(key) + str(exception)
raise KeyError(msg) from None
elif isinstance(actual, set) and isinstance(desired, set):
if actual.symmetric_difference(desired):
actual_extra = actual.difference(desired)
desired = desired.difference(actual)
raise KeyError("Actual and desired sets differ. "
f"Actual extra values: {actual_extra}, "
f"Desired extra values: {desired_extra}")
error = 0.
elif isinstance(actual, str) and isinstance(desired, str):
if actual != desired:
raise ValueError(
'actual %s, desired %s strings have different values' % (actual, desired))
error = 0.0
elif isinstance(actual, bool) and isinstance(desired, bool):
if actual != desired:
raise ValueError(
'actual %s, desired %s booleans have different values' % (actual, desired))
error = 0.0
elif actual is None and desired is None:
error = 0.0
# array values
elif isinstance(actual, np.ndarray) and isinstance(desired, np.ndarray):
if actual.dtype == object or desired.dtype == object:
if actual.dtype == object:
warnings.warn(
f"The function, assert_near_equal, does not support the actual value ndarray "
f"type of: '"
f"{type(actual.dtype)}'.")
if desired.dtype == object:
warnings.warn(
f"The function, assert_near_equal, does not support the desired value ndarray "
f"type of: '"
f"{type(desired.dtype)}'.")
error = 0.0
else:
actual = np.atleast_1d(actual)
desired = np.atleast_1d(desired)
if actual.shape != desired.shape:
raise ValueError(
'actual and desired have differing shapes.'
' actual {}, desired {}'.format(actual.shape, desired.shape))
# check to see if the entire array is made of floats. If not, loop through all values
if not np.all(np.isnan(actual) == np.isnan(desired)):
if actual.size == 1 and desired.size == 1:
raise ValueError('actual %s, desired %s' % (actual, desired))
else:
raise ValueError('actual and desired values have non-matching nan'
' values')
if np.linalg.norm(desired) == 0:
error = np.linalg.norm(actual)
else:
error = np.linalg.norm(actual - desired) / np.linalg.norm(desired)
if abs(error) > tolerance:
if actual.size < 10 and desired.size < 10:
raise ValueError('actual %s, desired %s, rel error %s, tolerance %s'
% (actual, desired, error, tolerance))
else:
raise ValueError('arrays do not match, rel error %.3e > tol (%.3e)' %
(error, tolerance))
elif isinstance(actual, tuple) and isinstance(desired, tuple):
error = 0.0
for act, des in zip(actual, desired):
new_error = assert_near_equal(act, des, tolerance)
error = max(error, new_error)
else:
raise ValueError(
'actual and desired have unexpected types: %s, %s' % (type(actual), type(desired)))
return error
[docs]def assert_equal_arrays(a1, a2):
"""
Check that two arrays are equal.
This is a simplified method useful when the arrays to be compared may
not be numeric. It simply compares the shapes of the two arrays and then
does a value by value comparison.
Parameters
----------
a1 : array
The first array to compare.
a2 : array
The second array to compare.
"""
assert a1.shape == a2.shape
for x, y in zip(a1.flat, a2.flat):
assert x == y
[docs]def assert_equal_numstrings(s1, s2, atol=1e-6, rtol=1e-6):
"""
Check that two strings containing numbers are equal after convering numerical parts to floats.
Parameters
----------
s1 : str
The first numeric string to compare.
s2 : str
The second numeric string to compare.
atol : float
Absolute error tolerance. Default is 1e-6.
rtol : float
Relative error tolerance. Default is 1e-6.
"""
assert snum_equal(s1, s2, atol=atol, rtol=rtol)
[docs]def skip_helper(msg):
"""
Raise a SkipTest.
Parameters
----------
msg : str
The skip messaage.
Raises
------
SkipTest
"""
raise unittest.SkipTest(msg)
[docs]class SkipParameterized(object):
"""
Replaces the parameterized class, skipping decorated test cases.
"""
[docs] @classmethod
def expand(cls, input, name_func=None, doc_func=None, skip_on_empty=False, **legacy):
"""
Decorate a test so that it raises a SkipTest.
Parameters
----------
input : iterable
Not used (part of parameterized API).
name_func : function
Not used (part of parameterized API).
doc_func : function
Not used (part of parameterized API).
skip_on_empty : bool
Not used (part of parameterized API).
**legacy : dict
Not used (part of parameterized API).
Returns
-------
function
The wrapper function.
"""
skip_msg = "requires 'parameterized' (install openmdao[test])"
def parameterized_expand_wrapper(f, instance=None):
"""
Wrap a function so that it raises a SkipTest.
f : function
Function to be wrapped.
instance : None
Not used (part of parameterized API).
Returns
-------
function
The wrapped function.
"""
return wraps(f)(lambda f: skip_helper(skip_msg))
return parameterized_expand_wrapper