"""A module containing various configuration checks for an OpenMDAO Problem."""
from collections import defaultdict
from packaging.version import Version
import pathlib
from io import StringIO
import numpy as np
import pickle
from openmdao.core.group import Group
from openmdao.core.component import Component
from openmdao.core.implicitcomponent import ImplicitComponent
from openmdao.utils.graph_utils import get_sccs_topo
from openmdao.utils.logger_utils import get_logger, TestLogger
from openmdao.utils.mpi import MPI
from openmdao.utils.hooks import _register_hook
from openmdao.utils.general_utils import printoptions
from openmdao.utils.units import _has_val_mismatch
from openmdao.utils.file_utils import _load_and_exec, text2html
from openmdao.utils.om_warnings import issue_warning, SetupWarning
from openmdao.utils.reports_system import register_report
# Private sentinel marking "nothing recorded yet"; it cannot collide with any real unit or value.
_UNSET = object()

# numpy changed its default array print formatting in 1.14. On 1.14 or newer, request the
# legacy 1.13 formatting so that stringified arrays look the same across numpy versions.
_npy_print_opts = {'legacy': '1.13'} if Version(np.__version__) >= Version("1.14") else {}
def _check_cycles(group, infos=None):
    """
    Find the cycles among the subsystems of a Group.

    If ``infos`` is given and at least one cycle is found, a human-readable description of
    the cycles is appended to it.

    Parameters
    ----------
    group : <Group>
        The Group being checked for dataflow issues
    infos : list
        List to collect informational messages.

    Returns
    -------
    list
        List of cycles, with subsystem names sorted in execution order.
    """
    subsystems = group._subsystems_allprocs

    def _exec_order(name):
        return subsystems[name].index

    cycles = []
    for scc in get_sccs_topo(group.compute_sys_graph(comps_only=False)):
        # a strongly connected component with a single member is not a cycle
        if len(scc) > 1:
            cycles.append(sorted(scc, key=_exec_order))

    if infos is not None and cycles:
        infos.append(f" Group '{group.pathname}' has the following cycles:")
        infos.extend(f" {cycle}" for cycle in cycles)
        infos.append('')

    return cycles
def _check_ubcs(group, warnings):
    """
    Collect messages about subsystems of a Group that are 'used before calculated'.

    Connections inside subsystems flagged as parallel are also checked for a compatible
    nonlinear solver.

    Parameters
    ----------
    group : <Group>
        The Group being checked for dataflow issues
    warnings : list
        List to collect warning messages.
    """
    out_of_order = group._check_order(reorder=False, recurse=False)
    for syspath in out_of_order:
        if syspath:
            prefix = f" In System '{syspath}', subsystem "
        else:
            prefix = " System "
        for tgt, srcs in out_of_order[syspath].items():
            warnings.append(f"{prefix}'{tgt}' executes out-of-order "
                            f"with respect to its source systems {srcs}\n")

    # subsystems whose MPI proc allocator is flagged parallel -> their nonlinear solver type
    parallel_solvers = {
        sub.name: sub.nonlinear_solver.SOLVER
        for sub, _ in group._subsystems_allprocs.values()
        if hasattr(sub, '_mpi_proc_allocator') and sub._mpi_proc_allocator.parallel
    }
    if parallel_solvers:
        _check_parallel_solvers(group, parallel_solvers)
def _check_parallel_solvers(group, parallel_solvers):
    """
    Warn about connections inside a parallel subgroup whose nonlinear solver is unsuitable.

    A warning is issued for each connection whose source and target both live in the same
    subsystem listed in ``parallel_solvers``, but whose next path segments below that
    subsystem differ, when that subsystem's solver is not NLBJ, Newton, or Broyden.

    Parameters
    ----------
    group : <Group>
        The Group being checked.
    parallel_solvers : dict
        Dictionary of parallel solvers keyed by subsystem names.
    """
    # index of the first name segment below this group (the model root has pathname '')
    depth = len(group.pathname.split('.')) if group.pathname else 0
    allowed = ["NL: NLBJ", "NL: Newton", "NL: BROYDEN"]

    for tgt_abs, src_abs in group._conn_global_abs_in2out.items():
        tgt_parts = tgt_abs.split('.')
        src_parts = src_abs.split('.')
        src_sys, tgt_sys = src_parts[depth], tgt_parts[depth]
        same_child = src_parts[depth + 1] == tgt_parts[depth + 1]

        if src_sys != tgt_sys or src_sys not in parallel_solvers:
            continue
        if parallel_solvers[src_sys] in allowed or same_child:
            continue
        issue_warning("Need to attach NonlinearBlockJac, NewtonSolver, or BroydenSolver to "
                      f"'{src_sys}' when connecting components inside parallel groups",
                      category=SetupWarning)
def _check_cycles_prob(prob, logger):
    """
    Log, at info level, every cycle found in any Group of the model.

    Parameters
    ----------
    prob : <Problem>
        The Problem being checked for cycles.
    logger : object
        The object that manages logging output.
    """
    header = "The following groups contain cycles:"
    details = []
    for group in prob.model.system_iter(include_self=True, recurse=True, typ=Group):
        _check_cycles(group, details)

    if details:
        logger.info(header)
        for line in details:
            logger.info(line)
def _check_ubcs_prob(prob, logger):
    """
    Log a single warning listing every out-of-order System in the model.

    Parameters
    ----------
    prob : <Problem>
        The Problem being checked for dataflow issues.
    logger : object
        The object that manages logging output.
    """
    header = "The following systems are executed out-of-order:\n"
    found = []
    for group in prob.model.system_iter(include_self=True, recurse=True, typ=Group):
        _check_ubcs(group, found)

    if found:
        found.sort()
        logger.warning(header + ''.join(found))
def _check_dup_comp_inputs(problem, logger):
    """
    Warn about components that have more than one input connected to the same source.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    if isinstance(problem.model, Component):
        return

    # group the connected inputs by their source
    inputs_by_src = defaultdict(list)
    for tgt, src in problem.model._conn_global_abs_in2out.items():
        inputs_by_src[src].append(tgt)

    lines = []
    for src, tgts in inputs_by_src.items():
        # for this source, group the target variable names by owning component
        vars_by_comp = defaultdict(list)
        for tgt in tgts:
            comp_path, var_name = tgt.rsplit('.', 1)
            vars_by_comp[comp_path].append(var_name)
        for comp_path in sorted(vars_by_comp):
            var_names = vars_by_comp[comp_path]
            if len(var_names) > 1:
                lines.append(" %s has inputs %s connected to %s\n"
                             % (comp_path, sorted(var_names), src))

    if lines:
        header = ["The following components have multiple inputs connected to the same source, ",
                  "which can introduce unnecessary data transfer overhead:\n"]
        logger.warning(''.join(header + sorted(lines)))
def _trim_str(obj, size):
    """
    Stringify an object, shortening the result if it exceeds a maximum length.

    Over-long float arrays are summarized by shape and norm; anything else is cut and
    suffixed with ' ...'.

    Parameters
    ----------
    obj : object
        Object to be stringified and trimmed.
    size : int
        Max allowable size of the returned string.

    Returns
    -------
    str
        The trimmed string.
    """
    is_array = isinstance(obj, np.ndarray)
    if is_array:
        with printoptions(**_npy_print_opts):
            text = str(obj)
    else:
        text = str(obj)

    if len(text) <= size:
        return text

    if is_array and np.issubdtype(obj.dtype, np.floating):
        return 'shape={}, norm={:<.3}'.format(obj.shape, np.linalg.norm(obj))
    return text[:size - 4] + ' ...'
def _list_has_val_mismatch(discretes, names, units, vals):
    """
    Return True if any of the given values don't match, subject to unit conversion.

    Every non-discrete entry is compared against the first non-discrete entry.

    Parameters
    ----------
    discretes : set-like
        Set of discrete variable names.
    names : list
        List of variable names.
    units : list
        List of units corresponding to names.
    vals : list
        List of values corresponding to names.

    Returns
    -------
    bool
        True if a mismatch was found, otherwise False.
    """
    if len(names) < 2:
        return False

    distinct_units = set(units)
    if len(distinct_units) > 1 and '' in distinct_units:
        # at least one case has no units and at least one does, so there must be a mismatch
        return True

    continuous = ((u, v) for n, u, v in zip(names, units, vals) if n not in discretes)
    first = next(continuous, None)
    if first is None:
        return False
    ref_units, ref_val = first
    return any(_has_val_mismatch(ref_units, ref_val, u, v) for u, v in continuous)
def _check_hanging_inputs(problem, logger):
    """
    Warn about model inputs that are not connected.

    An input counts as unconnected when its source is an ``_auto_ivc`` output, unless the
    input is declared as a design variable. Each is listed by promoted name, with its
    absolute name in parentheses.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    model = problem.model
    if isinstance(model, Component):
        return

    abs2prom = model._var_allprocs_abs2prom['input']
    desvars = problem.driver._designvars

    unconnected = []
    for abs_tgt, src in model._conn_global_abs_in2out.items():
        if not src.startswith('_auto_ivc.'):
            continue
        prom_tgt = abs2prom[abs_tgt]
        # inputs declared as design variables are considered connected
        if desvars and prom_tgt in desvars:
            continue
        unconnected.append((prom_tgt, abs_tgt))

    if unconnected:
        lines = [f' {prom} ({abs_name})\n' for prom, abs_name in sorted(unconnected)]
        logger.warning(''.join(["The following inputs are not connected:\n"] + lines))
def _check_comp_has_no_outputs(problem, logger):
    """
    Issue a logger warning if any components do not have any outputs.

    Both continuous and discrete outputs, on all procs, are counted.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    msg = []
    for comp in problem.model.system_iter(include_self=True, recurse=True, typ=Component):
        # any() stops at the first output name instead of materializing the whole list
        has_outputs = any(True for _ in comp.abs_name_iter('output', local=False, discrete=True))
        if not has_outputs:
            msg.append(" %s\n" % comp.pathname)

    if msg:
        logger.warning(''.join(["The following Components do not have any outputs:\n"] + msg))
def _check_auto_ivc_warnings(problem, logger):
    """
    Forward any warnings stored on the model's ``_auto_ivc_warnings`` attribute to the logger.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    for warning_msg in getattr(problem.model, "_auto_ivc_warnings", ()):
        logger.warning(warning_msg)
def _check_system_configs(problem, logger):
    """
    Call ``check_config`` on every System in the model, including the model itself.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output; passed to each System's check_config.
    """
    model = problem.model
    for sys_ in model.system_iter(include_self=True, recurse=True):
        sys_.check_config(logger)
def _has_ancestor_solver(path, solvers):
    """
    Return True if any proper ancestor of the given path is a key in ``solvers``.

    The root path '' is an ancestor of every non-empty path. The path itself is not
    considered.

    Parameters
    ----------
    path : str
        The path to the system being checked.
    solvers : dict
        Dictionary of solvers keyed by system pathname.

    Returns
    -------
    bool
        True if the given path has an ancestor with a solver.
    """
    parts = path.split('.') if path else []
    # nearest ancestor first: 'a.b.c' -> 'a.b', 'a', ''
    ancestors = ('.'.join(parts[:i]) for i in range(len(parts) - 1, -1, -1))
    return any(anc in solvers for anc in ancestors)
def _check_solvers(problem, logger):
    """
    Search over all solvers and warn about unsupported configurations.

    Report any implicit component that does not implement solve_nonlinear and
    solve_linear or have an iterative nonlinear and linear solver upstream of it.
    Report any cycles that do not have an iterative nonlinear solver and either
    an iterative linear solver or a DirectSolver upstream of it.
    Finally, report groups whose cycles contain sub-cycles that might be better solved in
    their own group.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    from openmdao.core.group import Group, iter_solver_info
    from openmdao.core.implicitcomponent import ImplicitComponent

    # pathnames of groups able to solve for their descendants -> (solver, maxiter)
    has_nl_solver = {}
    has_lin_solver = {}

    group = problem.model
    lst = []  # solver info tuples of every visited Group, used for the sub-cycle report

    # NOTE(review): the ancestor lookups rely on the visitor yielding a group before its
    # descendants -- confirm against Group._sys_tree_visitor.
    for tup in group._sys_tree_visitor(iter_solver_info,
                                       predicate=lambda s: isinstance(s,
                                                                      (Group, ImplicitComponent))):
        path, pathclass, sccs, lnslv, nlslv, lnmaxiter, nlmaxiter, _, isgrp, \
            nl_cansolve, lin_cansolve = tup

        if isgrp:
            lst.append(tup)

        if not isgrp or sccs:  # a group with cycles or an implicit component
            missing_kinds = []
            if not nl_cansolve and not _has_ancestor_solver(path, has_nl_solver):
                missing_kinds.append('nonlinear')
            if not lin_cansolve and not _has_ancestor_solver(path, has_lin_solver):
                missing_kinds.append('linear')

            if missing_kinds:
                missing = ' or '.join(missing_kinds)
                if isgrp:
                    sccs = [tuple(sorted(s)) for s in sccs]
                    msg = (f"Group '{path}' contains cycles {sccs}, but does not have an iterative "
                           f"{missing} solver.")
                else:
                    # Test membership in the list, not the joined string: 'linear' is a
                    # substring of 'nonlinear', so a string test would wrongly report
                    # solve_linear when only the nonlinear solver is missing.
                    fncs = []
                    if 'nonlinear' in missing_kinds and nlslv != 'solve_nonlinear':
                        fncs.append('solve_nonlinear')
                    if 'linear' in missing_kinds and lnslv != 'solve_linear':
                        fncs.append('solve_linear')
                    fncs = ' or '.join(fncs)
                    msg = (f"{pathclass} '{path}' contains implicit variables but does "
                           f"not implement {fncs} or have an iterative {missing} solver.")
                logger.warning(msg)

        if isgrp and lin_cansolve:
            has_lin_solver[path] = (lnslv, lnmaxiter)
        if isgrp and nl_cansolve:
            has_nl_solver[path] = (nlslv, nlmaxiter)

    seen = set()
    lines = []
    for tup in lst:
        # the 8th tuple entry is reported below as the number of non-cycle subsystems
        path, pathclass, sccs, lnslv, nlslv, lnmaxiter, nlmaxiter, num_noncycle, _, \
            _, _ = tup
        if sccs:
            if pathclass in seen:
                continue  # each group class is reported only once
            if num_noncycle == 0 and len(sccs) == 1:
                continue  # don't show groups without sub-cycles
            seen.add(pathclass)
            lines.append(f"'{path}' ({pathclass}) NL: {nlslv} (maxiter={nlmaxiter}), LN: "
                         f"{lnslv} (maxiter={lnmaxiter}):")
            for i, scc in enumerate(sccs):
                lines.append(f" Cycle {i}: {sorted(scc)}")
            if num_noncycle:
                lines.append(f" Number of non-cycle subsystems: {num_noncycle}")
            lines.append('')

    if lines:
        final = []
        final.append("The following groups contain sub-cycles. Performance and/or convergence "
                     "may improve")
        final.append("if these sub-cycles are solved separately in their own group.")
        final.append('')
        final.extend(lines)
        logger.warning('\n'.join(final))
def _check_missing_recorders(problem, logger):
    """
    Warn if no recorder is attached anywhere: Problem, Driver, any System, or its solvers.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    # Problem- and Driver-level recorders
    if problem._rec_mgr._recorders or problem.driver._rec_mgr._recorders:
        return

    def _solver_records(solver):
        return bool(solver and solver._rec_mgr._recorders)

    # System-level and Solver-level recorders anywhere in the model
    for system in problem.model.system_iter(include_self=True, recurse=True):
        if (system._rec_mgr._recorders or _solver_records(system.nonlinear_solver)
                or _solver_records(system.linear_solver)):
            return

    logger.warning("The Problem has no recorder of any kind attached")
def _check_unserializable_options(problem, logger, check_recordable=True):
    """
    Warn about options that cannot be pickled and therefore won't be recorded.

    Options of every System, its linear and nonlinear solvers, and the nonlinear solver's
    ``linear_solver`` and ``linesearch`` (when present) are inspected.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    check_recordable : bool
        If False, warn about all unserializable options.
        If True, warn only about unserializable options that do not have 'recordable' set to False.
    """
    from openmdao.recorders.case_recorder import PICKLE_VER

    def _is_picklable(value):
        # pickling can fail with many exception types, so any Exception means "not picklable"
        try:
            pickle.dumps(value, PICKLE_VER)
        except Exception:
            return False
        return True

    def _check_opts(obj, name=None):
        if not obj:
            return
        for key, val in obj.options.items():
            if _is_picklable(val):
                continue
            name_str = name + " " if name else ""
            if obj.options._dict[key]['recordable']:
                logger.warning(f"{obj.msginfo}: {name_str}option '{key}' is not serializable "
                               "(cannot be pickled) but 'recordable=False' has not been set. "
                               f"No options will be recorded for this {obj.__class__.__name__} "
                               "unless 'recordable' is set to False for this option.")
            elif not check_recordable:
                logger.warning(f"{obj.msginfo}: {name_str}option '{key}' is not serializable "
                               "(cannot be pickled) and will not be recorded.")

    for system in problem.model.system_iter(include_self=True, recurse=True):
        _check_opts(system)
        _check_opts(system.linear_solver, 'linear_solver')
        nl_solver = system.nonlinear_solver
        if not nl_solver:
            continue
        _check_opts(nl_solver, 'nonlinear_solver')
        _check_opts(getattr(nl_solver, 'linear_solver', None), 'linear_solver')
        _check_opts(getattr(nl_solver, 'linesearch', None), 'linesearch')
def _check_all_unserializable_options(problem, logger):
    """
    Warn about every unpicklable option, including those marked ``recordable=False``.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    _check_unserializable_options(problem, logger, check_recordable=False)
def _get_promoted_connected_ins(g):
    """
    Find all inputs that are promoted above the level where they are explicitly connected.

    Works recursively over the local subgroups of ``g``.

    Parameters
    ----------
    g : Group
        Starting Group.

    Returns
    -------
    defaultdict
        Absolute input name keyed to [promoting_groups, manually_connecting_groups]
    """
    prom2abs_in = g._var_allprocs_prom2abs_list['input']
    local_abs2prom = g._var_abs2prom['input']
    result = defaultdict(lambda: ([], []))

    # inputs connected explicitly at this level, recorded as (promoted name, group path)
    for prom_name in g._manual_connections:
        for abs_name in prom2abs_in[prom_name]:
            result[abs_name][1].append((prom_name, g.pathname))

    for sub in g._subgroups_myproc:
        sub_result = _get_promoted_connected_ins(sub)
        for abs_name, (sub_proms, sub_mans) in sub_result.items():
            promoters, manuals = result[abs_name]
            promoters.extend(sub_proms)
            manuals.extend(sub_mans)

        for abs_name, sub_prom in sub._var_abs2prom['input'].items():
            # an unchanged promoted name means the input was promoted up out of sub
            if local_abs2prom[abs_name] != sub_prom:
                continue
            if abs_name in sub_result and sub_result[abs_name][1]:
                result[abs_name][0].append(sub.pathname)

    return result
def _check_explicitly_connected_promoted_inputs(problem, logger):
    """
    Check for any inputs that are explicitly connected AND promoted above their connection group.

    Parameters
    ----------
    problem : <Problem>
        The problem being checked.
    logger : object
        The object that manages logging output.
    """
    for inp, (proms, mans) in _get_promoted_connected_ins(problem.model).items():
        if not proms:
            continue
        # there can only be one manual connection (else an exception would've been raised)
        man_prom, man_group = mans[0]
        if len(proms) == 1:
            where = "group '%s'" % proms[0]
        else:
            # keep promoting groups that are the connecting group itself or one of its ancestors
            relevant = [p for p in proms if p == man_group or man_group.startswith(p + '.')]
            where = "groups %s" % sorted(relevant)
        logger.warning("Input '%s' was explicitly connected in group '%s' as '%s', but was "
                       "promoted up from %s." % (inp, man_group, man_prom, where))
# Dicts of checks by name, mapped to the corresponding function that performs the check.
# Each function must be of the form f(problem, logger).

# checks run when no specific checks are requested
_default_checks = {
    'out_of_order': _check_ubcs_prob,
    'system': _check_system_configs,
    'solvers': _check_solvers,
    'dup_inputs': _check_dup_comp_inputs,
    'missing_recorders': _check_missing_recorders,
    'unserializable_options': _check_unserializable_options,
    'comp_has_no_outputs': _check_comp_has_no_outputs,
    'auto_ivc_warnings': _check_auto_ivc_warnings,
}

# every available check, including those that must be requested explicitly
_all_checks = {
    **_default_checks,
    'cycles': _check_cycles_prob,
    'unconnected_inputs': _check_hanging_inputs,
    'promotions': _check_explicitly_connected_promoted_inputs,
    'all_unserializable_options': _check_all_unserializable_options,
}

# used for 'all': 'all_unserializable_options' already reports what 'unserializable_options' does
_all_non_redundant_checks = {name: func for name, func in _all_checks.items()
                             if name != 'unserializable_options'}

#
# Command line interface functions
#
def _check_config_setup_parser(parser):
    """
    Add the command line arguments of the 'openmdao check' command to a parser.

    Parameters
    ----------
    parser : argparse subparser
        The parser we're adding options to.
    """
    default_names = sorted(_default_checks)
    extra_names = sorted(set(_all_checks) - set(_default_checks))

    parser.add_argument('file', nargs=1, help='Python file containing the model')
    parser.add_argument('-o', action='store', dest='outfile', help='output file')
    parser.add_argument('-p', '--problem', action='store', dest='problem', help='Problem name')
    parser.add_argument('-c', action='append', dest='checks', default=[],
                        help='Only perform specific check(s). Default checks are: %s. '
                             'Other available checks are: %s' % (default_names, extra_names))
def _get_checks(checks):
    """
    Translate a checks specification into the names of the checks to run.

    Parameters
    ----------
    checks : bool, None, or iterable of str
        True selects the default checks. Any other falsy value selects none. An iterable
        containing 'all' selects every non-redundant check. Any other iterable is returned
        unchanged.

    Returns
    -------
    list, tuple, or the given iterable
        Names of the checks to run.
    """
    if checks is True:
        return sorted(_default_checks)
    if not checks:
        return ()
    if 'all' in checks:
        return sorted(_all_non_redundant_checks)
    return checks
class _Log2File:
    """
    Minimal logger stand-in that writes each message, followed by a newline, to a file-like object.

    Parameters
    ----------
    f : file-like
        Object with a ``write`` method that receives the messages.

    Attributes
    ----------
    f : file-like
        Destination of all logged messages.
    """

    def __init__(self, f):
        self.f = f

    def info(self, msg):
        """
        Write a message and a trailing newline.

        Parameters
        ----------
        msg : str
            The message to write.
        """
        self.f.write(msg)
        self.f.write('\n')

    # every log level is written the same way
    error = warning = debug = info
def _run_check_report(prob):
    """
    Run the requested config checks and write their combined output to an HTML report.

    Unrecognized check names are reported on stdout and skipped. No file is written if
    no check produced output.

    Parameters
    ----------
    prob : <Problem>
        The Problem being checked. If its ``_check`` attribute is None, the default
        checks are run.
    """
    buf = StringIO()
    log = _Log2File(buf)
    requested = True if prob._check is None else prob._check

    for name in _get_checks(requested):
        if name not in _all_checks:
            print(f"WARNING: '{name}' is not a recognized check. Available checks are: "
                  f"{sorted(_all_checks)}")
            continue
        print('-' * 30, f'Checking {name}', '-' * 30, file=buf)
        _all_checks[name](prob, log)

    output = buf.getvalue()
    if not output:
        return

    report_file = pathlib.Path(prob.get_reports_dir() / 'checks.html')
    with open(report_file, 'w') as f:
        f.write(text2html(output))
# entry point for check report
def _check_report_register():
    """
    Register the 'checks' report, hooked to Problem 'final_setup' ('post').
    """
    register_report('checks', _run_check_report, 'Config checks',
                    'Problem', 'final_setup', 'post')
def _check_config_cmd(options, user_args):
    """
    Implement 'openmdao check' by hooking a config check into the user's script and running it.

    A post-``final_setup`` hook is registered for the Problem selected by
    ``options.problem``, then the script named in ``options.file`` is executed.

    Parameters
    ----------
    options : argparse Namespace
        Command line options.
    user_args : list of str
        Args to be passed to the user script.
    """
    def _check_config(prob):
        # use the default checks if none were given, and expand 'all'
        if not options.checks:
            options.checks = sorted(_default_checks)
        elif 'all' in options.checks:
            options.checks = sorted(_all_non_redundant_checks)

        if MPI and prob.comm.rank != 0:
            # if not rank 0, don't display anything, but still do the config check to prevent
            # any MPI hangs due to collective calls
            logger = TestLogger()
        elif options.outfile is None:
            logger = get_logger('check_config', out_stream='stdout',
                                out_file=None, use_format=True)
        else:
            logger = get_logger('check_config', out_file=options.outfile, use_format=True)

        prob.check_config(logger, options.checks)

    # register the hook
    _register_hook('final_setup', class_name='Problem', inst_id=options.problem,
                   post=_check_config, exit=True)
    _load_and_exec(options.file[0], user_args)
def check_allocate_complex_ln(group, under_cs):
    """
    Return True if linear vector should be complex.

    This happens when a solver needs derivatives under complex step.

    Parameters
    ----------
    group : <Group>
        Group to be checked.
    under_cs : bool
        Flag indicates if complex vectors were allocated in a containing Group or were force
        allocated in setup.

    Returns
    -------
    bool
        True if linear vector should be complex.
    """
    # a 'cs' approximation scheme on this group puts it, and everything below it, under
    # complex step
    under_cs |= 'cs' in group._approx_schemes
    if under_cs and group.nonlinear_solver is not None and \
            group.nonlinear_solver.supports['gradients']:
        return True

    for sub, _ in group._subsystems_allprocs.values():
        if isinstance(sub, Group) and check_allocate_complex_ln(sub, under_cs):
            return True
        elif isinstance(sub, ImplicitComponent):
            if sub.nonlinear_solver is not None and sub.nonlinear_solver.supports['gradients']:
                # Special case, gradient-supporting solver in an ImplicitComponent.
                return True

    return False