# Source code for openmdao.utils.assert_utils

"""Functions for making assertions about OpenMDAO Systems."""

from fnmatch import fnmatch
import warnings
import unittest
from contextlib import contextmanager
from functools import wraps

import numpy as np

    from jaxlib.xla_extension import ArrayImpl
except ImportError:
    ArrayImpl = None

from openmdao.core.component import Component
from import Group
from openmdao.jacobians.dictionary_jacobian import DictionaryJacobian
from openmdao.utils.general_utils import pad_name
from openmdao.utils.om_warnings import reset_warning_registry
from openmdao.utils.mpi import MPI
from openmdao.utils.testing_utils import snum_equal

@contextmanager
def assert_warning(category, msg, contains_msg=False, ranks=None):
    """
    Context manager asserting that a warning is issued.

    Parameters
    ----------
    category : class
        The class of the expected warning.
    msg : str
        The text of the expected warning.
    contains_msg : bool
        Set to True to check that the warning text contains msg, rather than checking equality.
    ranks : int or list of int, optional
        The global ranks on which the warning is expected. A list, tuple, set, range, or
        ndarray of ranks is accepted; any other value is treated as a single rank.

    Yields
    ------
    None

    Raises
    ------
    AssertionError
        If the expected warning is not raised.
    RuntimeError
        If ranks is specified but MPI is not active.
    """
    with reset_warning_registry():
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            yield

    if ranks is not None:
        if MPI is None:
            raise RuntimeError("ranks argument has been specified but MPI is not active")
        # Accept any common container of ranks. Wrapping e.g. a tuple as a single element
        # would make the membership test below always fail and silently skip the check.
        if isinstance(ranks, (list, tuple, set, frozenset, range, np.ndarray)):
            ranks = list(ranks)
        else:
            ranks = [ranks]
        if MPI.COMM_WORLD.rank not in ranks:
            return

    for warn in w:
        if contains_msg:
            warn_clause = msg in str(warn.message)
        else:
            warn_clause = str(warn.message) == msg

        if issubclass(warn.category, category) and warn_clause:
            break
    else:
        msg = f"Did not see expected {category.__name__}:\n{msg}"
        if w:
            ws = '\n'.join([str(ww.message) for ww in w])
            categories = '\n'.join([str(ww.category.__name__) for ww in w])
            msg += f"\nDid see warnings [{categories}]:\n{ws}"
        raise AssertionError(msg)
@contextmanager
def assert_warnings(expected_warnings):
    """
    Context manager asserting that every one of the expected warnings is issued.

    Parameters
    ----------
    expected_warnings : iterable of (class, str)
        Pairs of (warning category, exact warning text), all of which must be seen.

    Yields
    ------
    None

    Raises
    ------
    AssertionError
        If any of the expected warnings is not raised.
    """
    with reset_warning_registry():
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            yield

    for expected_category, expected_text in expected_warnings:
        # A match needs both a compatible category and identical message text.
        seen = any(issubclass(rec.category, expected_category) and
                   str(rec.message) == expected_text
                   for rec in caught)
        if not seen:
            raise AssertionError("Did not see expected %s: %s" %
                                 (expected_category.__name__, expected_text))
@contextmanager
def assert_no_warning(category, msg=None, contains=False):
    """
    Context manager asserting that a particular warning is not issued.

    Parameters
    ----------
    category : class
        The warning class to look for (subclasses also match).
    msg : str or None
        The warning text to look for. When None, any warning of the given class fails.
    contains : bool
        When True, fail if the warning text contains msg; otherwise require exact equality.

    Yields
    ------
    None

    Raises
    ------
    AssertionError
        If a matching warning is raised.
    """
    with reset_warning_registry():
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            yield

    for rec in caught:
        if not issubclass(rec.category, category):
            continue
        text = str(rec.message)
        if msg is None:
            raise AssertionError(f"Found warning: {category} {text}")
        if contains:
            if msg in text:
                raise AssertionError(f"Found warning: {category} containing '{msg}'")
        elif text == msg:
            raise AssertionError(f"Found warning: {category} {msg}")
def assert_check_partials(data, atol=1e-6, rtol=1e-6):
    """
    Raise an error if any entry from the return from check_partials is above a tolerance.

    Parameters
    ----------
    data : dict of dicts of dicts
            First key:
                is the component name;
            Second key:
                is the (output, input) tuple of strings;
            Third key:
                is one of ['rel error', 'abs error', 'magnitude', 'J_fd', 'J_fwd', 'J_rev'];
                an optional 'rank_inconsistent' entry flags derivatives that differ across
                processes.

            For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
                forward - fd, adjoint - fd, forward - adjoint (or a list of such tuples).
            For 'J_fd', 'J_fwd', 'J_rev' the value is: A numpy array representing the computed
                Jacobian for the three different methods of computation.
    atol : float
        Absolute error. Default is 1e-6.
    rtol : float
        Relative error. Default is 1e-6.

    Raises
    ------
    ValueError
        If any error norm exceeds its tolerance, if a forward-vs-fd absolute error is NaN,
        or if any derivative was flagged as rank-inconsistent.
    """
    error_string = ''
    absrel_header = 'abs/rel'
    wrt_header = '< output > wrt < variable >'
    norm_value_header = 'norm value'
    len_absrel_width = len(absrel_header)
    # Labels for the three entries of each error tuple, in order:
    # forward - fd, reverse (adjoint) - fd, forward - reverse.
    norm_types = ['fwd-fd', 'rev-fd', 'fwd-rev']
    len_norm_type_width = max(len(s) for s in norm_types)

    for comp in data:
        len_wrt_width = len(wrt_header)
        len_norm_width = len(norm_value_header)
        bad_derivs = []
        inconsistent_derivs = set()

        # Find all derivatives whose errors exceed tolerance.
        # Also, size the output to precompute column extents.
        for key, pair_data in data[comp].items():
            var, wrt = key

            if pair_data.get('rank_inconsistent'):
                inconsistent_derivs.add(key)

            for error_type, tolerance in [('abs error', atol), ('rel error', rtol), ]:
                actuals = pair_data[error_type]
                if not isinstance(actuals, list):
                    actuals = [actuals]

                for actual in actuals:
                    for error_val, mode in zip(actual, norm_types):
                        if error_val is None:
                            # Reverse derivatives only computed on matrix free comps.
                            continue

                        in_error = False
                        if not np.isnan(error_val):
                            if not np.allclose(error_val, 0.0, atol=tolerance):
                                if error_type == 'rel error' and mode == 'fwd-fd' and \
                                        np.allclose(pair_data['J_fwd'], 0.0, atol=atol) and \
                                        np.allclose(pair_data['J_fd'], 0.0, atol=atol):
                                    # Special case: both fd and fwd are really tiny, so we want
                                    # to ignore the rather large relative errors.
                                    in_error = False
                                else:
                                    # This is a bona-fide error.
                                    in_error = True
                        elif error_type == 'abs error' and mode == 'fwd-fd':
                            # Either analytic or approximated derivatives contain a NaN.
                            in_error = True

                        if in_error:
                            wrt_string = f'{var} wrt {wrt}'
                            norm_string = str(error_val)
                            bad_derivs.append((wrt_string, norm_string, error_type, mode))
                            len_wrt_width = max(len_wrt_width, len(wrt_string))
                            len_norm_width = max(len_norm_width, len(norm_string))

        if not (bad_derivs or inconsistent_derivs):
            continue

        comp_error_string = ''
        for wrt_string, norm_string, error_type, mode in bad_derivs:
            err_msg = '{0} | {1} | {2} | {3}'.format(
                pad_name(wrt_string, len_wrt_width),
                pad_name(error_type.split()[0], len_absrel_width),
                pad_name(mode, len_norm_type_width),
                pad_name(norm_string, len_norm_width)) + '\n'
            comp_error_string += err_msg

        if inconsistent_derivs:
            comp_error_string += (
                "\nInconsistent derivs across processes for keys: "
                f"{sorted(inconsistent_derivs)}.\nCheck that distributed outputs are properly "
                "reduced when computing\nderivatives of serial inputs.")

        name_header = 'Component: {}\n'.format(comp)
        len_name_header = len(name_header)
        header = len_name_header * '-' + '\n'
        header += name_header
        header += len_name_header * '-' + '\n'
        header += '{0} | {1} | {2} | {3}'.format(
            pad_name(wrt_header, len_wrt_width),
            pad_name(absrel_header, len_absrel_width),
            pad_name('norm', len_norm_type_width),
            pad_name(norm_value_header, len_norm_width),
        ) + '\n'
        header += '{0} | {1} | {2} | {3}'.format(
            len_wrt_width * '-',
            len_absrel_width * '-',
            len_norm_type_width * '-',
            len_norm_width * '-',
        ) + '\n'
        error_string += header + comp_error_string

    # if error string then raise error with that string
    if error_string:
        header_line1 = 'Assert Check Partials failed for the following Components'
        header_line2 = f'with absolute tolerance = {atol} and relative tolerance = {rtol}'
        header_width = max(len(header_line1), len(header_line2))
        header = '\n' + header_width * '=' + '\n'
        header += header_line1 + '\n'
        header += header_line2 + '\n'
        header += header_width * '=' + '\n'
        raise ValueError(header + error_string)
def assert_check_totals(totals_data, atol=1e-6, rtol=1e-6):
    """
    Raise an error if any entry from the return from check_totals is above a tolerance.

    Parameters
    ----------
    totals_data : Dict of Dicts of Tuples of Floats
        First key:
            is the (output, input) tuple of strings;
        Second key:
            is one of ['rel error', 'abs error', 'magnitude', 'fdstep'];

        For 'rel error', 'abs error', 'magnitude' the value is: A tuple containing norms for
            forward - fd, adjoint - fd, forward - adjoint.
    atol : float
        Absolute error. Default is 1e-6.
    rtol : float
        Relative error. Default is 1e-6.

    Raises
    ------
    ValueError
        If any error exceeds its tolerance, if an absolute error is NaN, or if inconsistent
        keys were reported.
    KeyError
        If an entry lacks its analytic ('J_fwd' or 'J_rev') or 'J_fd' jacobian.
    """
    fails = []
    incon_keys = set()
    for key, dct in totals_data.items():
        if 'inconsistent_keys' in dct:
            incon_keys = dct['inconsistent_keys']
        Jname = 'J_fwd' if 'J_fwd' in dct else 'J_rev'
        try:
            dct[Jname]
            dct['J_fd']
        except Exception as err:
            raise err.__class__(f"For key {key}: {err}") from err

        try:
            # Error entries are paired positionally; entries that were not computed are None.
            for erel, eabs in zip(dct['rel error'], dct['abs error']):
                # A NaN relative error (e.g. 0/0) is not treated as a failure.
                if erel is not None and not np.isnan(erel):
                    if erel > rtol:
                        raise ValueError(f"rel tolerance of {erel} > allowed rel tolerance "
                                         f"of {rtol}.")
                if eabs is not None:
                    # NaN compares False against atol, so test it explicitly: a NaN absolute
                    # error means one of the jacobians contains a NaN.
                    if np.isnan(eabs):
                        raise ValueError("abs error is NaN; a derivative contains a NaN value.")
                    if eabs > atol:
                        raise ValueError(f"abs tolerance of {eabs} > allowed abs tolerance "
                                         f"of {atol}.")
        except ValueError as err:
            fails.append((key, dct, err, Jname))

    fail_list = []

    if fails:
        fail_list.extend(
            [f"Totals differ for {key}:\nAnalytic:\n{dct[Jname]}\nFD:\n{dct['J_fd']}\n{err}"
             for key, dct, err, Jname in fails])

    if incon_keys:
        ders = [f"{sof} wrt {swrt}" for sof, swrt in sorted(incon_keys)]
        fail_list.append(f"During total derivative computation, the following partial derivatives "
                         "resulted in serial inputs that were inconsistent across processes: "
                         f"{ders}.")

    if fails or incon_keys:
        raise ValueError('\n\n'.join(fail_list))
def assert_no_approx_partials(system, include_self=True, recurse=True, method='any', excludes=None):
    """
    Raise assertion error if any component within system is using approximated partials.

    Parameters
    ----------
    system : System
        The system under which to search for approximated partials.
    include_self : bool
        If True, include this system in the iteration.
    recurse : bool
        If True, iterate over the whole tree under this system.
    method : str
        Specifically look for Components with this method of approx partials. Values can be
        'cs', 'fd', or 'any'. 'any' means either 'cs' or 'fd'. The default is 'any'.
    excludes : str, iter of str, or None
        Glob patterns for pathnames to exclude from the check. Default is None, which
        excludes nothing.

    Raises
    ------
    AssertionError
        If a subsystem of group is found to be using approximated partials.
    """
    if isinstance(excludes, str):
        excludes = [excludes, ]
    elif excludes is not None:
        # Materialize once so a generator is not exhausted after the first component.
        excludes = list(excludes)

    has_approx_partials = False
    msg = 'The following components use approximated partials:\n'
    for s in system.system_iter(include_self=include_self, recurse=recurse):
        if not isinstance(s, Component):
            continue
        if excludes is not None and any(fnmatch(s.pathname, exclude) for exclude in excludes):
            continue
        if s._approx_schemes:
            if method == 'any' or method in s._approx_schemes:
                has_approx_partials = True
                approx_partials = [(k, v['method'])
                                   for k, v in s._declared_partials_patterns.items()
                                   if 'method' in v and v['method']]
                msg += ' ' + s.pathname + '\n'
                # Use a distinct loop name: reusing ``method`` would overwrite the argument
                # and change the filter applied to every later component.
                for key, approx_method in approx_partials:
                    key = (str(key[0]), str(key[1]))
                    msg += ' of={0:12s} wrt={1:12s} method={2:2s}\n'.format(key[0], key[1],
                                                                          approx_method)

    if has_approx_partials:
        raise AssertionError(msg)
def assert_no_dict_jacobians(system, include_self=True, recurse=True):
    """
    Raise an assertion error if any Group within system is found to be using dictionary jacobians.

    Parameters
    ----------
    system : System
        The system under which to search for Groups using dictionary jacobians.
    include_self : bool
        If True, include this system in the iteration.
    recurse : bool
        If True, iterate over the whole tree under this system.

    Raises
    ------
    AssertionError
        If a Group in the iterated tree has a DictionaryJacobian.
    """
    offenders = [' ' + grp.pathname
                 for grp in system.system_iter(include_self=include_self, recurse=recurse,
                                               typ=Group)
                 if isinstance(grp._jacobian, DictionaryJacobian)]
    if offenders:
        lines = ['The following groups use dictionary jacobians:\n']
        lines.extend(offenders)
        raise AssertionError('\n'.join(lines))
def assert_near_equal(actual, desired, tolerance=1e-15, tol_type='rel'):
    """
    Check relative error.

    Determine that the relative error between `actual` and `desired`
    is within `tolerance`. If `desired` is zero, then use absolute error.

    Can handle some data structures. Generates warnings for data types it cannot handle.

    Parameters
    ----------
    actual : float, array-like, dict
        The value from the test.
    desired : float, array-like, dict
        The value expected.
    tolerance : float
        Maximum relative or absolute error.
        For relative tolerance: ``(actual - desired) / desired``.
        For absolute tolerance: ``(actual - desired)``.
    tol_type : {'rel', 'abs'}
        Type of error to use: 'rel' for relative error, 'abs' for absolute error.
        Default is set to 'rel'.

    Returns
    -------
    float
        The error.

    Raises
    ------
    ValueError
        If the values differ by more than the tolerance, have different types or shapes,
        or have NaN entries in different positions.
    KeyError
        If dict keys or set members differ.
    """
    def _is_numeric_scalar(val):
        # bool is a subclass of int but must remain a bool so it is compared as one below.
        return isinstance(val, (int, float, complex, np.number)) and not isinstance(val, bool)

    # Try to make similar things of the same type so they can be compared.
    # Make arrays out of numeric scalars (python scalars and numpy scalars such as float32).
    if _is_numeric_scalar(actual):
        actual = np.atleast_1d(actual)
    if _is_numeric_scalar(desired):
        desired = np.atleast_1d(desired)

    # Handle jax arrays, if available.
    if ArrayImpl is not None:
        if isinstance(actual, ArrayImpl):
            actual = np.atleast_1d(actual)
        if isinstance(desired, ArrayImpl):
            desired = np.atleast_1d(desired)

    # If desired is a numeric list or tuple, make an ndarray out of it.
    if isinstance(actual, (list, tuple)):
        actual = np.asarray(actual)
    if isinstance(desired, (list, tuple)):
        desired = np.asarray(desired)

    # In case they are PromAbsDict and other dict-like objects.
    if isinstance(actual, dict) and type(actual) is not dict:
        actual = dict(actual)
    if isinstance(desired, dict) and type(desired) is not dict:
        desired = dict(desired)

    if type(actual) is not type(desired):
        raise ValueError(f'actual {type(actual)}, desired {type(desired)} have different types')

    if isinstance(actual, type) and isinstance(desired, type):
        if actual != desired:
            raise ValueError(
                'actual type %s, and desired type %s are different' % (actual, desired))
        return 0

    # The code below can only handle these data types.
    _supported_types = [dict, set, str, bool, np.ndarray, type(None)]

    if type(actual) not in _supported_types:
        warnings.warn(
            f"The function, assert_near_equal, does not support the actual value type: '"
            f"{type(actual)}'.")
        return 0
    if type(desired) not in _supported_types:
        warnings.warn(
            f"The function, assert_near_equal, does not support the desired value type: '"
            f"{type(desired)}'.")
        return 0

    if isinstance(actual, dict) and isinstance(desired, dict):

        actual_keys = set(actual.keys())
        desired_keys = set(desired.keys())

        if actual_keys.symmetric_difference(desired_keys):
            msg = 'Actual and desired keys differ. Actual extra keys: {}, Desired extra keys: {}'
            actual_extra = actual_keys.difference(desired_keys)
            desired_extra = desired_keys.difference(actual_keys)
            raise KeyError(msg.format(actual_extra, desired_extra))

        error = 0.

        for key in actual_keys:
            try:
                new_error = assert_near_equal(actual[key], desired[key], tolerance, tol_type)
                error = max(error, new_error)
            except ValueError as exception:
                msg = '{}: '.format(key) + str(exception)
                raise ValueError(msg) from None
            except KeyError as exception:
                msg = '{}: '.format(key) + str(exception)
                raise KeyError(msg) from None

    elif isinstance(actual, set) and isinstance(desired, set):

        if actual.symmetric_difference(desired):
            actual_extra = actual.difference(desired)
            desired_extra = desired.difference(actual)
            raise KeyError("Actual and desired sets differ. "
                           f"Actual extra values: {actual_extra}, "
                           f"Desired extra values: {desired_extra}")

        error = 0.

    elif isinstance(actual, str) and isinstance(desired, str):
        if actual != desired:
            raise ValueError(
                'actual %s, desired %s strings have different values' % (actual, desired))
        error = 0.0

    elif isinstance(actual, bool) and isinstance(desired, bool):
        if actual != desired:
            raise ValueError(
                'actual %s, desired %s booleans have different values' % (actual, desired))
        error = 0.0

    elif actual is None and desired is None:
        error = 0.0

    # array values
    elif isinstance(actual, np.ndarray) and isinstance(desired, np.ndarray):
        if actual.dtype == object or desired.dtype == object:
            if actual.dtype == object:
                warnings.warn(
                    f"The function, assert_near_equal, does not support the actual value ndarray "
                    f"type of: '"
                    f"{type(actual.dtype)}'.")
            if desired.dtype == object:
                warnings.warn(
                    f"The function, assert_near_equal, does not support the desired value ndarray "
                    f"type of: '"
                    f"{type(desired.dtype)}'.")
            error = 0.0
        else:
            actual = np.atleast_1d(actual)
            desired = np.atleast_1d(desired)
            if actual.shape != desired.shape:
                raise ValueError(
                    'actual and desired have differing shapes.'
                    ' actual {}, desired {}'.format(actual.shape, desired.shape))

            # NaN entries must appear in the same positions in both arrays.
            if not np.all(np.isnan(actual) == np.isnan(desired)):
                if actual.size == 1 and desired.size == 1:
                    raise ValueError('actual %s, desired %s' % (actual, desired))
                else:
                    raise ValueError('actual and desired values have non-matching nan'
                                     ' values')

            # Fall back to absolute error when the desired value is all zeros.
            if np.linalg.norm(desired) == 0 or tol_type == 'abs':
                error = np.linalg.norm(actual - desired)
            else:
                error = np.linalg.norm(actual - desired) / np.linalg.norm(desired)

            if abs(error) > tolerance:
                if actual.size < 10 and desired.size < 10:
                    raise ValueError('actual %s, desired %s, %s error %s, tolerance %s'
                                     % (actual, desired, tol_type, error, tolerance))
                else:
                    raise ValueError('arrays do not match, %s error %.3e > tol (%.3e)' %
                                     (tol_type, error, tolerance))

    else:
        raise ValueError(
            'actual and desired have unexpected types: %s, %s' % (type(actual), type(desired)))

    return error
def assert_equal_arrays(a1, a2):
    """
    Check that two arrays are equal.

    This is a simplified method useful when the arrays to be compared may
    not be numeric. It simply compares the shapes of the two arrays and then
    does a value by value comparison.

    Parameters
    ----------
    a1 : array
        The first array to compare.
    a2 : array
        The second array to compare.

    Raises
    ------
    AssertionError
        If the shapes differ or any pair of corresponding elements compares unequal.
    """
    # Raise explicitly rather than using ``assert`` so the check still runs under ``python -O``.
    if a1.shape != a2.shape:
        raise AssertionError(f"Array shapes differ: {a1.shape} != {a2.shape}")
    for idx, (x, y) in enumerate(zip(a1.flat, a2.flat)):
        if not (x == y):
            raise AssertionError(f"Arrays differ at flat index {idx}: {x!r} != {y!r}")
def assert_equal_numstrings(s1, s2, atol=1e-6, rtol=1e-6):
    """
    Check that two strings containing numbers are equal after converting numerical parts to floats.

    Parameters
    ----------
    s1 : str
        The first numeric string to compare.
    s2 : str
        The second numeric string to compare.
    atol : float
        Absolute error tolerance. Default is 1e-6.
    rtol : float
        Relative error tolerance. Default is 1e-6.

    Raises
    ------
    AssertionError
        If snum_equal reports that the strings are not equal.
    """
    # Raise explicitly rather than using ``assert`` so the check still runs under ``python -O``.
    if not snum_equal(s1, s2, atol=atol, rtol=rtol):
        raise AssertionError(f"Numeric strings differ (atol={atol}, rtol={rtol}):\n{s1}\n{s2}")
def skip_helper(msg):
    """
    Raise a unittest SkipTest carrying the given message.

    Parameters
    ----------
    msg : str
        The skip message.

    Raises
    ------
    SkipTest
        Always.
    """
    skip_exc = unittest.SkipTest(msg)
    raise skip_exc
[docs]class SkipParameterized(object): """ Replaces the parameterized class, skipping decorated test cases. """
[docs] @classmethod def expand(cls, input, name_func=None, doc_func=None, skip_on_empty=False, **legacy): """ Decorate a test so that it raises a SkipTest. Parameters ---------- input : iterable Not used (part of parameterized API). name_func : function Not used (part of parameterized API). doc_func : function Not used (part of parameterized API). skip_on_empty : bool Not used (part of parameterized API). **legacy : dict Not used (part of parameterized API). Returns ------- function The wrapper function. """ skip_msg = "requires 'parameterized' (install openmdao[test])" def parameterized_expand_wrapper(f, instance=None): """ Wrap a function so that it raises a SkipTest. f : function Function to be wrapped. instance : None Not used (part of parameterized API). Returns ------- function The wrapped function. """ return wraps(f)(lambda f: skip_helper(skip_msg)) return parameterized_expand_wrapper