"""Functions for plotting the dynamic shapes dependency graph."""
import networkx as nx
from openmdao.core.problem import Problem
from openmdao.utils.mpi import MPI
from openmdao.utils.file_utils import _load_and_exec
import openmdao.utils.hooks as hooks
from openmdao.utils.general_utils import common_subpath
def _view_dyn_shapes_setup_parser(parser):
    """
    Add the command line arguments used by the 'openmdao view_dyn_shapes' command.

    Parameters
    ----------
    parser : argparse subparser
        The parser that will receive the arguments.
    """
    # (positional flags, keyword options) for each argument, in registration order.
    arg_specs = [
        (('file',), {'nargs': 1, 'help': 'Python file containing the model.'}),
        (('-p', '--problem'), {'action': 'store', 'dest': 'problem', 'help': 'Problem name'}),
        (('-o',), {'default': 'shape_dep_graph.png', 'action': 'store', 'dest': 'outfile',
                   'help': 'plot file.'}),
        (('-t', '--title'), {'action': 'store', 'dest': 'title', 'help': 'title of plot.'}),
        (('--no_display',), {'action': 'store_true', 'dest': 'no_display',
                             'help': "don't display the plot."}),
    ]
    for flags, kwargs in arg_specs:
        parser.add_argument(*flags, **kwargs)
def _view_dyn_shapes_cmd(options, user_args):
    """
    Execute the user script with hooks that plot its dynamic shape dependency graph.

    A pre-hook is registered on Problem.setup.  When it fires, it registers a post-hook on
    Group._setup_dynamic_shapes that calls view_dyn_shapes on the model.  The user script is
    then executed.  Nothing is returned.

    Parameters
    ----------
    options : argparse Namespace
        Command line options.
    user_args : list of str
        Args to be passed to the user script.
    """
    def _view_shape_graph(model):
        view_dyn_shapes(model, outfile=options.outfile, show=not options.no_display,
                        title=options.title)

    def _set_dyn_hook(prob):
        # We can't wait until the end of Problem.setup because we'll die in _setup_sizes
        # if there were any unresolved dynamic shapes, so put the hook immediately after
        # _setup_dynamic_shapes (registered with exit=True so the run presumably stops once
        # the plot has been made).  inst_id is None here because no system's pathname will
        # have been set at the time this hook is triggered.
        hooks._register_hook('_setup_dynamic_shapes', class_name='Group', inst_id=None,
                             post=_view_shape_graph, exit=True)
        hooks._setup_hooks(prob.model)

    # Register the setup hook.  The --problem option was previously parsed but ignored.
    # Passing it as inst_id restricts the hook to that problem.  When the option is omitted
    # it is None, the same value used explicitly for inst_id above, so behavior without
    # --problem should be unchanged.
    # NOTE(review): assumes a Problem's hook inst_id is its name; confirm against hooks.py.
    hooks._register_hook('setup', 'Problem', inst_id=options.problem, pre=_set_dyn_hook,
                         ncalls=1)
    _load_and_exec(options.file[0], user_args)
def view_dyn_shapes(root, outfile='shape_dep_graph.png', show=True, title=None):
    """
    Generate a plot file containing the dynamic shape dependency graph.

    Optionally displays the plot.  Under MPI, only rank 0 produces the plot.

    Parameters
    ----------
    root : System or Problem
        The top level system or Problem.
    outfile : str, optional
        The name of the plot file. Defaults to 'shape_dep_graph.png'.
    show : bool, optional
        If True, display the plot. Defaults to True.
    title : str, optional
        Sets the title of the plot.  If None, a title is generated.  Any pathname prefix
        common to all graph nodes is then named in the title and stripped from the node
        labels.

    Raises
    ------
    RuntimeError
        If root is a subsystem rather than the model, if matplotlib is not installed, or
        if the shape dependency graph has not been computed yet.
    """
    if MPI and MPI.COMM_WORLD.rank != 0:
        return

    if isinstance(root, Problem):
        system = root.model
    else:
        system = root
        if system.pathname != '':
            raise RuntimeError("view_dyn_shapes cannot be called on a subsystem of the model. "
                               "Call it with the Problem or the model.")

    try:
        import matplotlib.pyplot as plt
    except ImportError as err:
        raise RuntimeError("The view_dyn_shapes command requires matplotlib.") from err

    graph = system._shapes_graph
    if graph is None:
        raise RuntimeError("Can't plot dynamic shape dependency graph because it hasn't been "
                           "computed yet. view_dyn_shapes must be called after problem setup().")

    if graph.order() == 0:
        print("The model has no dynamically shaped variables.")
        return

    # common_idx must be defined on every path: it is used below to build the node labels.
    # It was previously left unset when a title was supplied, which raised NameError.
    # A user-supplied title may not name the common group, so labels keep full pathnames
    # in that case.
    common_idx = 0
    if title is None:
        # keep the names from being super long by removing any common subpath
        common = common_subpath(graph.nodes())
        if common:
            title = f"Dynamic shape dependencies in group '{common}'"
            common_idx = len(common) + 1  # also skip the '.' separator after the prefix
        else:
            title = "Dynamic shape dependencies"

    abs2meta = system._var_allprocs_abs2meta
    dyn_names = ('shape_by_conn', 'compute_shape', 'copy_shape')

    # Node colors: red = shape still unknown (None, labelled '?'); blue = shape known and
    # the variable was declared with one of the dynamic-shape options; green = shape known
    # and no dynamic-shape option set.  Each label is the shape prepended to the name.
    node_colors = []
    node_labels = {}
    for n in graph:
        meta = abs2meta['input'][n] if n in abs2meta['input'] else abs2meta['output'][n]
        shape = meta['shape']
        if shape is None:
            shape = '?'
            node_colors.append('red')
        elif any(meta.get(shname, False) for shname in dyn_names):
            node_colors.append('blue')
        else:
            node_colors.append('green')
        node_labels[n] = f"{shape}: {n[common_idx:]}"

    # node_colors was built by iterating graph, so it lines up with draw_networkx's
    # default nodelist (the graph's own node order).
    nx.draw_networkx(graph, with_labels=True, node_color=node_colors, labels=node_labels)
    plt.axis('off')  # turn off axis
    plt.title(title)
    plt.savefig(outfile)
    if show:
        plt.show()

    # TODO: add a legend
    # TODO: use a better graph plotting lib, maybe D3 or something else, to get better layout