from difflib import get_close_matches
import networkx as nx
from networkx import dfs_edges, dfs_postorder_nodes
import numpy as np
from numpy import isscalar, reshape
from numbers import Number
from collections import deque
from copy import deepcopy
import textwrap
from itertools import chain
import webbrowser
import threading
import time
from http.server import HTTPServer
from openmdao.utils.graph_utils import networkx_to_dot, dot_to_html
from openmdao.utils.general_utils import common_subpath, is_undefined, truncate_str, \
all_ancestors, collect_error, collect_errors
from openmdao.utils.array_utils import array_connection_compatible, shape_to_len, \
get_global_dist_shape, evenly_distrib_idxs, array_hash
from openmdao.utils.units import is_compatible
from openmdao.utils.units import unit_conversion
from openmdao.utils.indexer import indexer, Indexer, idx_list_to_index_array
from openmdao.utils.om_warnings import issue_warning, UnitsWarning
from openmdao.utils.units import _is_unitless, has_val_mismatch
from openmdao.visualization.tables.table_builder import generate_table
from openmdao.core.constants import INT_DTYPE
from openmdao.utils.mpi import MPI
# use hex colors here because using English color names was sometimes causing the
# help dialog to show the wrong colors.
GRAPH_COLORS = {
'input': 'peachpuff3',
'output': 'skyblue3',
'highlight': '#66ff00',
'ambiguous': '#FF0800',
'boundary': '#D3D3D3',
}
_shape_func_map = {
# (src distributed, tgt distributed, fwd)
(False, False, True): 'serial2serial_fwd_shape',
(False, False, False): 'serial2serial_rev_shape',
(False, True, True): 'serial2dist_fwd_shape',
(False, True, False): 'serial2dist_rev_shape',
(True, False, True): 'dist2serial_fwd_shape',
(True, False, False): 'dist2serial_rev_shape',
(True, True, True): 'dist2dist_fwd_shape',
(True, True, False): 'dist2dist_rev_shape',
}
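# Illustrative lookup (not part of the resolution logic itself): resolving the shape of a
# distributed target fed by a serial source in the forward direction would use
# _shape_func_map[(False, True, True)], i.e. the 'serial2dist_fwd_shape' handler.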
_continuous_copy_meta = ['val', 'units', 'shape', 'discrete', 'remote', 'distributed',
'global_shape']
_discrete_copy_meta = ['val', 'discrete', 'remote']
def _strip_np(shape):
"""
Strip np.int64 entries from a shape tuple.
Parameters
----------
shape : tuple or None
Shape to strip.
Returns
-------
tuple or None
Stripped shape.
"""
if shape is None:
return shape
# so displayed shapes won't have np.int64 in them
ret = []
for item in shape:
ret.append(int(item))
return tuple(ret)
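# For example, _strip_np((np.int64(2), np.int64(3))) returns (2, 3), so displayed shapes
# won't contain np.int64 entries.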
def are_compatible_values(a, b, discrete, src_indices=None):
"""
Check compatibility of values a and b.
Parameters
----------
a : any
First value to compare.
b : any
Second value to compare.
discrete : bool
Whether the values are discrete.
src_indices : Indexer or None
The src_indices applied to a, if any.
Returns
-------
bool
True if a and b are compatible, False otherwise.
"""
if discrete:
return (isinstance(b, type(a)) or isinstance(a, type(b)))
a = np.atleast_1d(a)
b = np.atleast_1d(b)
try:
np.promote_types(a.dtype, b.dtype)
except TypeError:
return False
if src_indices is None:
return array_connection_compatible(a.shape, b.shape)
return array_connection_compatible(b.shape, src_indices.indexed_src_shape)
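# Sketch of the intent above (illustrative, not exhaustive): discrete values are compatible
# when one is an instance of the other's type; continuous values are compatible when their
# dtypes can be promoted to a common type and their shapes (or the shape produced by
# src_indices on the source) pass array_connection_compatible.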
class ConnError(ValueError):
"""
An error raised when a connection is incompatible.
Parameters
----------
msg : str
The error message.
ident : hashable object
Identifier of the object responsible for issuing the error.
Attributes
----------
ident : hashable object
Identifier of the object responsible for issuing the error.
"""
def __init__(self, msg, ident=None):
super().__init__(msg)
self.ident = ident
flag_type = np.uint8
base = flag_type(1)
DISCRETE = base << 0
REMOTE = base << 1
REQUIRE_CONNECTION = base << 2
SHAPE_BY_CONN = base << 3
UNITS_BY_CONN = base << 4
AMBIGUOUS_UNITS = base << 5
AMBIGUOUS_VAL = base << 6
# NOTE: if adding more flags, make sure the flag_type is large enough to hold all the flags.
AMBIGUOUS = AMBIGUOUS_VAL | AMBIGUOUS_UNITS
NONE_UP_VAL = REMOTE | AMBIGUOUS_VAL
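# A minimal illustration of how these bit flags combine (hypothetical values):
#     flags = REMOTE | AMBIGUOUS_VAL   # equals NONE_UP_VAL
#     bool(flags & AMBIGUOUS)          # True, since AMBIGUOUS_VAL is part of AMBIGUOUS
#     bool(flags & DISCRETE)           # False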
class Defaults():
"""
Container for default values of a variable.
Attributes
----------
val : any
Default value.
units : str or None
Default units.
src_shape : tuple or None
Default source shape.
"""
__slots__ = ['val', 'units', 'src_shape']
def __init__(self, val=None, units=None, src_shape=None):
self.val = val
self.units = units
self.src_shape = src_shape
def __iter__(self):
"""
Iterate over the attributes.
Yields
------
any
The attribute value.
str or None
The units.
tuple or None
The source shape.
"""
yield self.val
yield self.units
yield self.src_shape
_global_to_update = ['global_shape', 'global_size']
_local_to_update = ['global_shape', 'global_size']
class NodeAttrs():
"""
Container for per-node variable metadata used by the connection graph.
Attributes
----------
pathname : str or None
The pathname of System adding the node.
rel_name : str or None
The name relative to the System adding the node.
_val : any or None
The value of the node.
_shape : tuple or None
The shape of the node.
_global_shape : tuple or None
The global shape of the node.
_units : str or None
The units of the node.
_src_inds_list : list or None
The source indices list for the node.
_meta : dict or None
The global metadata of the node. This is always None for promoted variable nodes.
_locmeta : dict or None
The local metadata of the node. This is always None for promoted variable nodes.
copy_shape : str or None
The name of the variable to copy shape from.
compute_shape : function or None
The function to compute the shape of the node.
copy_units : str or None
The name of the variable to copy units from.
compute_units : function or None
The function to compute the units of the node.
distributed : bool or None
Whether the node is distributed.
defaults : Defaults
The default values for the node.
flags : int
The flags for the node.
"""
__slots__ = ('pathname', 'rel_name', '_val', '_shape', '_global_shape', '_units', 'defaults',
'_src_inds_list', 'flags', '_meta', '_locmeta', 'copy_shape', 'compute_shape',
'copy_units', 'compute_units', 'distributed')
def __init__(self, pathname=None, rel_name=None):
"""
Initialize NodeAttrs with all attributes set to None or default values.
"""
self.pathname = pathname
self.rel_name = rel_name
self.defaults = Defaults(None, None, None)
self._meta = None
self._locmeta = None
self.flags = flag_type(0)
self._src_inds_list = []
self._val = None
self._shape = None
self._global_shape = None
self._units = None
self.copy_shape = None
self.compute_shape = None
self.copy_units = None
self.compute_units = None
self.distributed = None # 3 states: None (unknown), False or True
def __getattr__(self, key):
"""
Get missing attribute.
Parameters
----------
key : str
The attribute name.
Returns
-------
None
Always returns None for missing attributes.
"""
return None
def __getitem__(self, key):
"""
Get attribute by key.
Parameters
----------
key : str
The attribute name.
Returns
-------
any
The value of the attribute.
"""
return getattr(self, key)
def __setitem__(self, key, value):
"""
Set attribute by key.
Parameters
----------
key : str
The attribute name.
value : any
The value to set.
"""
setattr(self, key, value)
def get(self, key, default=None):
"""
Get attribute by key with a default value if not found.
Parameters
----------
key : str
The attribute name.
default : any, optional
The default value to return if the attribute is not found.
Returns
-------
any
The value of the attribute or the default value.
"""
return getattr(self, key, default)
def update(self, kwargs):
"""
Update attributes from a dictionary.
This method is used by networkx to populate initial node attributes.
Parameters
----------
kwargs : dict
Dictionary of attribute names and values to set.
"""
# networkx uses this to populate the initial node attributes.
for key, value in kwargs.items():
setattr(self, key, value)
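# NodeAttrs is deliberately dict-like (via __getitem__, __setitem__, get and update) so
# networkx can treat it as a node attribute mapping. Illustrative sketch with hypothetical
# names:
#     attrs = NodeAttrs(pathname='sub', rel_name='x')
#     attrs['units'] = 'm'     # same as attrs.units = 'm'
#     attrs.get('shape')       # None until a shape has been set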
def msgname(self):
"""
Get the full message name combining pathname and relative name.
Returns
-------
str
The full name if pathname exists, otherwise just the relative name.
"""
if self.pathname:
return f"{self.pathname}.{self.rel_name}"
return self.rel_name
def __repr__(self):
"""
Return a string representation of the NodeAttrs object.
Returns
-------
str
A formatted table representation of non-None attributes.
"""
rows = [[key, getattr(self, key)] for key in self.__slots__
if getattr(self, key) is not None and key != '_meta']
table = generate_table(rows, tablefmt='plain')
return f"{self.__class__.__name__}:\n{table}"
@property
def meta(self):
return self._meta
@property
def locmeta(self):
return self._locmeta
@property
def src_inds_list(self):
return self._src_inds_list
@src_inds_list.setter
def src_inds_list(self, value):
self._src_inds_list = value
if self._locmeta is not None:
self._locmeta['src_inds_list'] = value
@property
def val(self):
return self._val
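# The setter below keeps value and shape consistent: discrete values are stored as-is; for
# continuous values the shape is inferred from the first non-scalar value assigned, a scalar
# assigned to an already-shaped node is broadcast with np.full, and later assignments are
# reshaped into the existing array in place.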
@val.setter
def val(self, value):
if self.flags & DISCRETE:
self._val = value
elif value is not None:
if self._shape is None:
self._val = value
if np.ndim(value) > 0:
# setting 'shape' here will update the meta and locmeta if appropriate
self.shape = np.shape(value)
else:
if self._shape == ():
if isscalar(value):
self._val = value
else:
self._val = value.item()
elif self._val is None:
if np.isscalar(value):
self._val = np.full(self._shape, value, dtype=float)
else:
self._val = reshape(value, self._shape)
else:
self._val[:] = reshape(value, self._shape)
@property
def shape(self):
return self._shape
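# The setter below only takes effect the first time a non-None shape is assigned. It strips
# numpy integer types from the shape, expands a previously stored scalar value to the new
# shape, sets global_shape for non-distributed nodes, and mirrors shape/size (and a default
# value) into meta and locmeta when those are present.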
@shape.setter
def shape(self, shape):
if self._shape is None and shape is not None:
shape = _strip_np(shape)
self._shape = shape
if self._shape != ():
if self._val is not None and np.ndim(self._val) == 0:
# if val is a scalar, reshape it to the new shape
self._val = np.full(shape, self._val)
if self.distributed is False:
self.global_shape = shape
if self._meta is not None:
self._meta['shape'] = shape
size = shape_to_len(shape)
self._meta['size'] = size
if self._locmeta is not None:
if self._val is None:
self._val = np.ones(shape)
self._locmeta['shape'] = shape
self._locmeta['size'] = size
@property
def size(self):
return shape_to_len(self._shape)
@property
def global_shape(self):
if self.distributed:
return self._global_shape
return self._shape
@global_shape.setter
def global_shape(self, global_shape):
self._global_shape = global_shape
if self._meta is not None:
self._meta['global_shape'] = global_shape
self._meta['global_size'] = self.global_size
if self._locmeta is not None:
self._locmeta['global_shape'] = global_shape
self._locmeta['global_size'] = self.global_size
@property
def global_size(self):
return shape_to_len(self.global_shape)
@property
def units(self):
return self._units
@units.setter
def units(self, units):
if units is not None and self._units is None:
if self.discrete:
raise ValueError("Cannot set units for discrete variable "
f"'{self.msgname()}'.")
self._units = units
if self._meta is not None:
self._meta['units'] = units
if self._locmeta is not None:
self._locmeta['units'] = units
@property
def discrete(self):
return bool(self.flags & DISCRETE)
@discrete.setter
def discrete(self, value):
if value:
self.flags |= DISCRETE
else:
self.flags &= ~DISCRETE
@property
def remote(self):
return bool(self.flags & REMOTE)
@remote.setter
def remote(self, value):
if value:
self.flags |= REMOTE
else:
self.flags &= ~REMOTE
@property
def require_connection(self):
return bool(self.flags & REQUIRE_CONNECTION)
@require_connection.setter
def require_connection(self, value):
if value:
self.flags |= REQUIRE_CONNECTION
else:
self.flags &= ~REQUIRE_CONNECTION
@property
def shape_by_conn(self):
return bool(self.flags & SHAPE_BY_CONN)
@shape_by_conn.setter
def shape_by_conn(self, value):
# if shape defaults are set, then this node's shape is known
if self.defaults.src_shape is not None:
return
if value:
self.flags |= SHAPE_BY_CONN
else:
self.flags &= ~SHAPE_BY_CONN
@property
def units_by_conn(self):
return bool(self.flags & UNITS_BY_CONN)
@units_by_conn.setter
def units_by_conn(self, value):
# if unit defaults are set, then this node's units are known
if self.defaults.units is not None:
return
if value:
self.flags |= UNITS_BY_CONN
else:
self.flags &= ~UNITS_BY_CONN
@property
def ambiguous_units(self):
return bool(self.flags & AMBIGUOUS_UNITS)
@ambiguous_units.setter
def ambiguous_units(self, value):
if value:
self.flags |= AMBIGUOUS_UNITS
else:
self.flags &= ~AMBIGUOUS_UNITS
@property
def ambiguous_val(self):
return bool(self.flags & AMBIGUOUS_VAL)
@ambiguous_val.setter
def ambiguous_val(self, value):
if value:
self.flags |= AMBIGUOUS_VAL
else:
self.flags &= ~AMBIGUOUS_VAL
@property
def ambiguous(self):
return bool(self.flags & AMBIGUOUS)
@property
def dyn_shape(self):
return self.shape_by_conn or self.copy_shape is not None or self.compute_shape is not None
@property
def dyn_units(self):
return self.units_by_conn or self.copy_units is not None or self.compute_units is not None
@property
def dynamic(self):
return self.dyn_shape or self.dyn_units
def units_from_child(self):
"""
Get units, preferring defaults if available.
Returns
-------
str or None
The units from defaults if available, None if ambiguous, otherwise the node's units.
"""
if self.defaults.units is not None:
return self.defaults.units
if self.ambiguous_units:
return None
return self.units
def shape_from_child(self, node_meta, src_indices):
"""
Get shape for a child node, accounting for defaults and src_indices.
Parameters
----------
node_meta : NodeAttrs
The metadata of the child node.
src_indices : Indexer or None
The source indices applied to this node.
Returns
-------
tuple or None
The shape from defaults if available, None if src_indices are partial or
if this node is distributed but the child is not, otherwise this node's shape.
"""
if self.defaults.src_shape is not None:
return self.defaults.src_shape
if src_indices is not None and not src_indices.is_full_slice():
# can't determine the shape of the parent if src_indices are present
return None
if self.distributed and not node_meta.distributed:
return None
return self.shape
def val_from_child(self, node_meta, src_indices):
"""
Get value for a child node, accounting for defaults and compatibility.
Parameters
----------
node_meta : NodeAttrs
The metadata of the child node.
src_indices : Indexer or None
The source indices applied to this node.
Returns
-------
any or None
The value from defaults if available, None if ambiguous, incompatible,
or if this node is distributed but the child is not, otherwise this node's value.
"""
if self.flags & NONE_UP_VAL or src_indices is not None:
return None
if self.defaults.val is not None:
return self.defaults.val
if self.ambiguous_val:
return None
if self.defaults.src_shape is not None and self.val is not None:
# this covers the situation where they set the src_shape but not the default val
# and the val in the node is not compatible with the src_shape being passed to the
# parent.
if np.shape(self.val) != self.defaults.src_shape:
return None
if self.distributed and not node_meta.distributed:
return None
return self.val
def dist_shapes(self, comm):
"""
Gather shapes from all processes in a communicator.
Parameters
----------
comm : MPI communicator
The MPI communicator to gather from.
Returns
-------
list
A list of shapes from all processes if MPI size > 1, otherwise a list
containing only this node's shape.
"""
if comm.size > 1:
return comm.allgather(self.shape)
else:
return [self.shape]
def as_dict(self):
"""
Return attributes as a dictionary, excluding internal metadata.
Returns
-------
dict
A dictionary of non-None attributes excluding _meta, _locmeta, and _val.
"""
skip = {'_meta', '_locmeta', '_val'}
ret = {}
for name in self.__slots__:
if name not in skip:
metaval = getattr(self, name)
if metaval is not None:
ret[name] = metaval
return ret
class AllConnGraph(nx.DiGraph):
"""
A graph for all connection info. Covers manual, implicit, and all promotions.
Every connection in the graph forms a tree structure with an absolute output name at its
root and all connected absolute input names as the leaf nodes.
Node keys are tuples of the form (io, name), where io is either 'i' or 'o', and name is the
name (either promoted or absolute) of a variable in the model.
src_indices are stored in the edges between nodes.
Attributes
----------
_mult_inconn_nodes : set
A set of nodes that have multiple incoming connections.
_input_input_conns : set
A set of nodes that have input to input connections.
_first_pass : bool
True if we're in the first pass of node data updates.
_required_conns : set
A set of input nodes that have required connections, direct or indirect, to an output node.
_resolved : set
A set of nodes that have been resolved and so can be skipped in the second pass.
_has_dynamic_shapes : bool
Whether the graph has dynamic shape behavior.
_has_dynamic_units : bool
Whether the graph has dynamic units behavior.
_dist_shapes : list or None
A list of shapes from all processes if MPI size > 1.
_dist_sizes : list or None
A list of sizes from all processes if MPI size > 1.
_dist_nodes : set
A set of nodes that are distributed.
_problem_meta : dict or None
The metadata of the problem.
_bad_conns : set or None
A set of nodes that have bad connections.
msginfo : str or None
The message information for the top level System.
comm : MPI communicator or None
The MPI communicator for the model.
_var_existence : dict or None
Per-process variable existence information from the model.
_var_allprocs_abs2meta : dict or None
Metadata for variables across all processes, from the model.
_var_allprocs_discrete : dict or None
Metadata for discrete variables across all processes, from the model.
_var_allprocs_abs2idx : dict or None
Mapping of absolute variable name to index across all processes, from the model.
_var_abs2meta : dict or None
Metadata for local variables, from the model.
_sync_auto_ivcs : dict
A dictionary of auto_ivcs that require synchronization when setting initial values.
_dangling_prom_inputs : set
A set of dangling promoted inputs.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._mult_inconn_nodes = set()
self._input_input_conns = set()
self._first_pass = True
self._required_conns = set()
self._resolved = set()
self._has_dynamic_shapes = False
self._has_dynamic_units = False
self._dist_shapes = None
self._dist_sizes = None
self._dist_nodes = set()
self._problem_meta = None
self.msginfo = None
self.comm = None
self._var_existence = None
self._var_allprocs_abs2meta = None
self._var_allprocs_discrete = None
self._var_allprocs_abs2idx = None
self._var_abs2meta = None
self._bad_conns = set()
self._sync_auto_ivcs = {}  # auto_ivcs that require sync when setting initial values
self._dangling_prom_inputs = set()
def _collect_error(self, msg, exc_type=None, tback=None, ident=None):
"""
Save an error message to raise as an exception later.
Parameters
----------
msg : str
The connection error message to be saved.
exc_type : class or None
The type of exception to be raised if this error is the only one collected.
tback : traceback or None
The traceback of a caught exception.
ident : int
Identifier of the object responsible for issuing the error.
"""
collect_error(msg, self._get_saved_errors(), exc_type, tback, ident, msginfo=self.msginfo)
def _get_saved_errors(self):
"""Get saved errors.
Returns
-------
list or None
The list of saved errors, or None if no problem metadata is available.
"""
if self._problem_meta is None:
return None
return self._problem_meta['saved_errors']
def find_node(self, pathname, varname, io=None):
"""
Find a node in the graph.
Parameters
----------
pathname : str
The current scoping system pathname.
varname : str
The variable name to find.
io : str or None
The I/O type of the variable. If None, outputs are searched first, then inputs.
Returns
-------
tuple of the form (io, name), where io is either 'i' or 'o'.
The node found.
"""
if pathname:
prefix = pathname + '.'
if varname.startswith(prefix):
name = varname
else:
name = pathname + '.' + varname
else:
name = varname
if io is None:
node = ('o', name)
if node not in self:
node = ('i', name)
else:
node = (io[0], name)
if node in self:
return node
msg = f"{pathname}: Variable '{varname}' not found."
guesses = sorted(set(get_close_matches(name, [n[1] for n, d in self.nodes(data=True)
if d['attrs'].pathname==pathname],
n=3, cutoff=0.15)))
guesses = [g for g in guesses if not g.startswith('_auto_ivc.')]
if guesses:
msg = f"{msg} Perhaps you meant one of the following variables: {guesses}."
raise KeyError(msg)
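# Illustrative lookups (hypothetical names):
#     find_node('sub', 'x')                 # tries ('o', 'sub.x'), then ('i', 'sub.x')
#     find_node('', 'comp.y', io='output')  # looks for ('o', 'comp.y') only
# A KeyError containing close-match suggestions is raised if no matching node is found.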
def top_name(self, node):
"""
Return the variable name relative to the top level group.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node to get the top name for.
Returns
-------
str or None
The variable name relative to the top level group or None if the node is not connected.
"""
if node[0] == 'i':
root = self.input_root(node)
return root[1] if root else None
else:
if self.out_degree(node) == 0:
return node[1]
for _, v in dfs_edges(self, node):
if v[0] == 'o' and self.out_degree(v) == 0:
return v[1]
def base_error(self, msg, src, tgt, src_indices=None):
"""
Return the error message for a connection error.
Parameters
----------
msg : str
The base message describing the incompatibility.
src : tuple of the form ('i' or 'o', name)
The source node.
tgt : tuple of the form ('i' or 'o', name)
The target node.
src_indices : Indexer, None, or False
If False, omit src_indices from the message; otherwise the src_indices stored on the
edge, if any, are included.
Returns
-------
str
The formatted connection error message.
"""
edge = (src, tgt)
edge_meta = self.edges[edge]
typ = edge_meta.get('type', None)
type_map = {None: 'promote', 'manual': 'connect', 'implicit': 'implicitly connect'}
typestr = type_map[typ]
if typestr == 'promote' and src[0] == 'i':
src, tgt = tgt, src # promotion for inputs is up the tree
fromto = f"'{src[1]}' to '{tgt[1]}'"
indstr = ''
if src_indices is not False: # False means don't include src_indices in the error message
src_indices = edge_meta.get('src_indices', None)
if src_indices is not None:
indstr = f" when applying index {truncate_str(src_indices, max_len=50)}"
return f"Can't {typestr} {fromto}{indstr}: {msg}"
def shape_error(self, src, tgt, src_shape, tgt_shape):
"""
Return an error message for a shape incompatibility.
Parameters
----------
src : tuple of the form ('i' or 'o', name)
source.
tgt : tuple of the form ('i' or 'o', name)
target.
src_shape : tuple
source shape.
tgt_shape : tuple
target shape.
Returns
-------
str
The error message.
"""
return self.base_error(f"shape {src_shape} of '{src[1]}' is incompatible with shape "
f"{tgt_shape} of '{tgt[1]}'.", src, tgt)
def value_error(self, going_up, src, tgt, src_val, tgt_val):
"""
Return an error message for a value incompatibility.
Parameters
----------
going_up : bool
If True, the error occurred while propagating data up the tree, so src and tgt
(and their values) are swapped in the message.
src : tuple of the form ('i' or 'o', name)
The source node.
tgt : tuple of the form ('i' or 'o', name)
The target node.
src_val : any
The source value.
tgt_val : any
The target value.
Returns
-------
str
The error message.
"""
if going_up:
src, tgt = tgt, src
src_val, tgt_val = tgt_val, src_val
if not self.nodes[tgt]['attrs'].discrete:
sshp = np.shape(src_val)
tshp = np.shape(tgt_val)
if sshp != tshp:
return self.shape_error(src, tgt, sshp, tshp)
return self.base_error(f"value {truncate_str(src_val, max_len=50)} of '{src[1]}' is "
f"incompatible with value {truncate_str(tgt_val, max_len=50)} of "
f"'{tgt[1]}'.", src, tgt)
def units_error(self, going_up, src, tgt, src_units, tgt_units):
"""
Return an error message for a units incompatibility.
Parameters
----------
going_up : bool
If True, the error occurred while propagating data up the tree, so src and tgt
(and their units) are swapped in the message.
src : tuple of the form ('i' or 'o', name)
The source node.
tgt : tuple of the form ('i' or 'o', name)
The target node.
src_units : str
The source units.
tgt_units : str
The target units.
Returns
-------
str
The error message.
"""
if going_up:
src, tgt = tgt, src
src_units, tgt_units = tgt_units, src_units
return self.base_error(f"units '{src_units}' of '{src[1]}' are incompatible with units "
f"'{tgt_units}' of '{tgt[1]}'.", src, tgt, src_indices=False)
def handle_error(self, going_up, src, tgt, exc):
"""
Given an exception either raise it or save it for later.
Parameters
----------
going_up : bool
If True, the error occurred while propagating data up the tree, so src and tgt
are swapped before reporting.
src : tuple of the form ('i' or 'o', name)
The source node.
tgt : tuple of the form ('i' or 'o', name)
The target node.
exc : Exception
The exception to handle.
"""
if going_up:
src, tgt = tgt, src
excstr = str(exc)
if isinstance(exc, ConnError):
ident = exc.ident
else:
ident = frozenset((src, tgt))
edge_meta = self.edges[src, tgt]
src_indices = edge_meta.get('src_indices', None)
excstr = self.base_error(msg=excstr, src=src, tgt=tgt, src_indices=src_indices)
self._collect_error(f"{self.msginfo}: {excstr}", tback=exc.__traceback__, ident=ident)
def msgname(self, node):
"""
Get full name of node, including absolute names if they differ from the promoted name.
Parameters
----------
node : tuple or str
The node to get the full name of. Tuple of the form (io, name), where io is either
'i' or 'o', or a variable name.
Returns
-------
str
The full name of the node.
"""
if node not in self: # might be a var name instead of a node
orig = node
node = ('i', node)
if node not in self:
node = ('o', node)
if node not in self:
raise ValueError(f"Node '{orig}' not found in the graph.")
names = self.absnames(node)
names = sorted(names)
if len(names) == 1:
names = names[0]
if node[1] == names:
return names
return f'{node[1]} ({names})'
def startswith(self, prefix, node):
"""
Check if the node name starts with the prefix.
Parameters
----------
prefix : str
The prefix to check.
node : tuple of the form ('i' or 'o', name)
The node to check.
Returns
-------
bool
True if the node name starts with the prefix, False otherwise.
"""
if prefix:
return node[1].startswith(prefix)
return True
def sync_auto_ivcs(self, model):
"""
Synchronize auto_ivc values that require it.
Parameters
----------
model : System
The system containing the auto_ivcs to be synchronized.
"""
has_vectors = model.has_vectors()
all_sync_auto_ivcs = {k: v for k, v in self._sync_auto_ivcs.items() if v is not None}
for allsync in model.comm.allgather(all_sync_auto_ivcs):
for name, owner in allsync.items():
if name not in all_sync_auto_ivcs:
if owner is not None:
all_sync_auto_ivcs[name] = owner
elif owner is not None and owner < all_sync_auto_ivcs[name]:
all_sync_auto_ivcs[name] = owner
for name, owner in sorted(all_sync_auto_ivcs.items()):
node = self.find_node('', name, io='o')
node_meta = self.nodes[node]['attrs']
if owner == model.comm.rank:
if has_vectors:
val = model._outputs[name]
else:
val = node_meta.val
model.comm.bcast(val, root=owner)
else:
val = model.comm.bcast(None, root=owner)
if has_vectors:
model._outputs[name] = val
else:
node_meta.val = val
# reset the flag so we don't sync again unless the user changes the value again
self._sync_auto_ivcs[name] = None
def get_val_from_src(self, system, name, units=None, indices=None, get_remote=False, rank=None,
vec_name='nonlinear', kind=None, flat=False, use_vec=False, src_node=None):
"""Get val from source.
Parameters
----------
system : System
The System requesting the value.
name : str
The name of the variable to get the value of.
units : str or None
The units to convert to before returning the value.
indices : int or iter of ints or None
The indices or slice to return.
get_remote : bool
If True, retrieve the value even if it is on a remote process.
rank : int or None
If not None, only gather the value to this rank.
vec_name : str
The name of the vector to use.
kind : str or None
The kind of variable to get the value of.
flat : bool
If True, return the flattened version of the value.
use_vec : bool
If True, use the vector to get the value.
src_node : tuple of the form ('i' or 'o', name) or None
The source node.
Returns
-------
numpy.ndarray
The requested value.
"""
node = self.find_node(system.pathname, name)
node_meta = self.nodes[node]['attrs']
if src_node is None:
src_node = self.get_root(node)
src_meta = self.nodes[src_node]['attrs']
src_inds_list = node_meta.src_inds_list
if not get_remote and self.comm.size > 1:
if src_meta.distributed and not node_meta.distributed:
raise RuntimeError(f"{self.msginfo}: Non-distributed variable '{node[1]}' has "
f"a distributed source, '{src_node[1]}', so you must retrieve "
"its value using 'get_remote=True'.")
elif node_meta.distributed and src_inds_list:
if src_meta.distributed:
model = system._problem_meta['model_ref']()
src_inds_list = self.inds_into_local_distrib(model, src_node, node,
src_inds_list)
elif src_meta.remote:
err = True
if self.comm.allreduce(err, op=MPI.LOR):
raise RuntimeError(f"{self.msginfo}: Can't retrieve distributed variable "
f"'{node[1]}' because its src_indices reference "
"entries from other processes. You can retrieve values "
"from all processes using "
"`get_val(<name>, get_remote=True)`.")
if use_vec:
val = system._abs_get_val(src_node[1], get_remote, rank, vec_name, kind, flat,
from_root=True)
else:
val = src_meta.val
if is_undefined(val):
raise ValueError(f"{system.msginfo}: Variable '{self.msgname(src_node)}' has not "
"been initialized.")
try:
val = self.convert_get(node, val, src_meta.units, node_meta.units,
src_inds_list, units, indices,
get_remote=get_remote)
except Exception as err:
raise ValueError(f"{system.msginfo}: Can't get value of '{node[1]}': {str(err)}")
if node[0] == 'i' and get_remote and node_meta.distributed and self.comm.size > 1:
# gather parts of the distrib input value from all procs
full_val = np.zeros(shape_to_len(node_meta.global_shape))
sizes = np.asarray(self.get_dist_sizes(node), dtype=INT_DTYPE)
offsets = np.zeros(sizes.size, dtype=INT_DTYPE)
offsets[1:] = np.cumsum(sizes[:-1])
self.comm.Allgatherv(val.ravel(), [full_val, sizes, offsets, MPI.DOUBLE])
val = np.reshape(full_val, node_meta.global_shape)
if flat and not node_meta.discrete:
val = val.ravel()
return val
def get_val(self, system, name, units=None, indices=None, get_remote=False, rank=None,
vec_name='nonlinear', kind=None, flat=False, from_src=True):
"""
Return the value of a variable.
Parameters
----------
system : System
The System requesting the value.
name : str
The name of the variable to get the value of.
units : str or None
The units to convert to before returning the value.
indices : int or iter of ints or None
The indices or slice to return.
get_remote : bool
If True, retrieve the value even if it is on a remote process.
rank : int or None
If not None, only gather the value to this rank.
vec_name : str
The name of the vector to use.
kind : str or None
The kind of variable to get the value of.
flat : bool
If True, return the flattened version of the value.
from_src : bool
If True, retrieve the value of an input variable from its connected source.
Returns
-------
numpy.ndarray
The requested value.
"""
if indices is not None and not isinstance(indices, Indexer):
indices = indexer(indices, flat_src=flat)
node = self.find_node(system.pathname, name)
node_meta = self.nodes[node]['attrs']
# ambiguous units aren't fatal during setup, but if we're getting a specific promoted
# input that has ambiguous units, it becomes fatal, so we need to check that here.
if node_meta.ambiguous_units:
raise ValueError(self.ambig_units_msg(node))
if from_src or node[0] == 'o':
return self.get_val_from_src(system, name, units=units, indices=indices,
get_remote=get_remote, rank=rank, vec_name=vec_name,
kind=kind, flat=flat, use_vec=system.has_vectors())
# since from_src is False, we're getting a specific input
# (must use absolute name or have only a single leaf node and no src_indices between the
# promoted input node and the leaf node)
leaves = list(self.leaf_input_iter(node))
if len(leaves) > 1:
raise ValueError(
f"{system.msginfo}: Promoted variable '{name}' refers to multiple "
"input variables so the choice of input is ambiguous. Either "
"use the absolute name of the input or set 'from_src=True' to "
"retrieve the value from the connected output.")
leaf_meta = self.nodes[leaves[0]]['attrs']
if system.has_vectors():
model = system._problem_meta['model_ref']()
if model._resolver.is_prom(node[1], 'input' if node[0] == 'i' else 'output'):
abs_name = model._resolver.prom2abs(node[1])
else:
abs_name = node[1]
val = system._abs_get_val(abs_name, get_remote, rank, vec_name, kind, flat,
from_root=True)
else:
val = leaf_meta.val
if is_undefined(val):
raise ValueError(f"{system.msginfo}: Variable '{self.msgname(node)}' has not "
"been initialized.")
if node != leaves[0]:
if leaf_meta.src_inds_list:
src_inds_list = leaf_meta.src_inds_list[len(node_meta.src_inds_list):]
if src_inds_list:
raise RuntimeError(f"Can't get the value of promoted input '{node[1]}' when "
"from_src is False because src_indices exist between the "
f"promoted input and the actual input '{leaves[0][1]}'.")
try:
val = self.convert_get(node, val, node_meta.units, leaf_meta.units, (), units, indices,
get_remote=get_remote)
except Exception as err:
raise ValueError(f"{system.msginfo}: Can't get value of '{node[1]}': {str(err)}")
if flat and not node_meta.discrete:
val = val.ravel()
return val
def inds_into_local_distrib(self, model, src_node, tgt_node, inds):
"""
Convert indices into distributed indices and verify that they only reference local entries.
Parameters
----------
model : Model
The model.
src_node : tuple of the form ('i' or 'o', name)
The source node.
tgt_node : tuple of the form ('i' or 'o', name)
The target node.
inds : list of Indexers
The indices to convert.
Returns
-------
list of Indexers
The converted indices.
"""
if inds and self.comm.size > 1 and self.nodes[src_node]['attrs'].distributed:
src = src_node[1]
# TODO: this could be expensive. Try to make it more efficient.
src_indices = idx_list_to_index_array(inds)
ssizes = model._var_sizes['output']
sidx = model._var_allprocs_abs2idx[src]
ssize = ssizes[self.comm.rank, sidx]
start = np.sum(ssizes[:self.comm.rank, sidx])
end = start + ssize
if np.any(src_indices < start) or np.any(src_indices >= end):
err = True
else:
err = False
if self.comm.allreduce(err, op=MPI.LOR):
raise RuntimeError(f"{model.msginfo}: Can't retrieve distributed variable "
f"'{tgt_node[1]}' because its src_indices reference "
"entries from other processes. You can retrieve values "
"from all processes using "
"`get_val(<name>, get_remote=True)`.")
if start > 0:
src_indices = src_indices - start
inds = [indexer(src_indices)]
return inds
def set_val(self, system, name, val, units=None, indices=None):
"""
Set the value of a variable.
Parameters
----------
system : System
The System setting the value.
name : str
The name of the variable to set the value of.
val : any
The value to set.
units : str or None
The units to convert to before setting the value.
indices : int or iter of ints or None
The indices or slice to set.
"""
node = self.find_node(system.pathname, name)
node_meta = self.nodes[node]['attrs']
nodes = self.nodes
src_node = self.get_root(node)
src_meta = nodes[src_node]['attrs']
src = src_node[1]
model = system._problem_meta['model_ref']()
if src_meta.discrete:
if system.has_vectors():
if src in model._discrete_outputs:
model._discrete_outputs[src] = val
if node[0] == 'i':
for abs_in in self.absnames(node):
if abs_in in model._discrete_inputs:
model._discrete_inputs[abs_in] = val
else:
self.set_tree_val(model, src_node, val)
return
# every variable is continuous from here down
if node[0] == 'o':
tgt_units = None
tgt_inds_list = ()
else:
tgt_units = node_meta.units
tgt_inds_list = node_meta.src_inds_list
src_units = src_meta.units
if indices is None:
inds = tgt_inds_list
else:
if not isinstance(indices, Indexer):
indices = indexer(indices)
inds = list(tgt_inds_list) + [indices]
# do unit conversion on given val if needed
try:
sval = self.convert_set(val, src_units, tgt_units, (), units)
except Exception as err:
self._collect_error(f"{system.msginfo}: Can't set value of '{self.msgname(node)}': "
f"{str(err)}")
return
if node_meta.remote:
if tgt_inds_list:
issue_warning(f"{model.msginfo}: Cannot set the value of '{node[1]}':"
" Setting the value of a remote connected input with"
" src_indices is currently not supported, you must call"
" `run_model()` to have the outputs populate their"
" corresponding inputs.")
if src_meta.remote:
return
if model.has_vectors():
srcval = model._abs_get_val(src, get_remote=False)
if inds and node[0] == 'i' and src_meta.distributed:
inds = self.inds_into_local_distrib(model, src_node, node, inds)
if np.ndim(srcval) > 0:
self.set_subarray(srcval, inds, sval, node)
else:
srcval = sval
model._outputs._abs_set_val(src, srcval)
if src in self._sync_auto_ivcs:
for leaf in self.leaf_input_iter(node):
if leaf[1] in model._vars_to_gather:
# mark this auto_ivc to be synchronized later
self._sync_auto_ivcs[src] = model._vars_to_gather[leaf[1]]
break
# also set the input if it's absolute
if node[0] == 'i' and node[1] in model._var_abs2meta['input']:
try:
tval = self.convert_set(val, tgt_units, tgt_units, (), units)
except Exception as err:
self._collect_error(f"{system.msginfo}: Can't set value of "
f"'{self.msgname(node)}': {str(err)}")
return
if indices is None:
model._inputs._abs_set_val(node[1], tval)
else:
model._inputs._abs_set_val(node[1], tval, idx=indices())
else:
srcval = src_meta.val
if srcval is not None:
if isinstance(srcval, Number):
if inds:
self._collect_error("Can't set a non-array using indices.")
return
src_meta.val = sval
srcval = src_meta.val
else:
self.set_subarray(srcval, inds, sval, node)
else:
if inds:
self._collect_error(f"Shape of '{name}' isn't known yet so you can't use "
f"indices to set it.")
return
srcval = sval
# propagate shape and value down the tree
self.set_tree_val(model, src_node, srcval)
def set_tree_val(self, model, src_node, srcval):
"""
Set the value of a source in the tree and propagate it down the tree.
Parameters
----------
model : Model
The model.
src_node : tuple of the form ('i' or 'o', name)
The source node.
srcval : any
The starting value to set.
"""
nodes = self.nodes
src_meta = nodes[src_node]['attrs']
src_meta.val = srcval
src_dist = src_meta.distributed
has_vecs = model.has_vectors()
for leaf in self.leaf_input_iter(src_node):
tgt_meta = nodes[leaf]['attrs']
if tgt_meta.remote:
continue
if tgt_meta.discrete:
tgt_meta.val = srcval
continue
src_inds_list = tgt_meta.src_inds_list
if not (src_dist or tgt_meta.distributed): # serial --> serial
if src_inds_list:
tgt_meta.val = self.get_subarray(srcval, src_inds_list)
else:
tgt_meta.val = srcval
if has_vecs:
if tgt_meta.discrete:
pass
else:
model._inputs._abs_set_val(leaf[1], tgt_meta.val)
def bfs_up_iter(self, node, include_self=True):
"""
Iterate up the tree from the given node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node to start from.
include_self : bool, optional
Whether to include the given node in the iteration.
"""
mypreds = self.predecessors
if include_self:
yield node
stack = [mypreds(node)]
while stack:
for n in stack.pop():
yield n
stack.append(mypreds(n))
def bfs_down_iter(self, node, include_self=True):
"""
Iterate down the tree from the given node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node to start from.
include_self : bool, optional
Whether to include the given node in the iteration.
"""
mysuccs = self.successors
if include_self:
yield node
stack = [mysuccs(node)]
while stack:
for n in stack.pop():
yield n
stack.append(mysuccs(n))
def check_add_edge(self, group, src, tgt, **kwargs):
"""
Check if an edge can be added to the graph and add it if it can.
Parameters
----------
group : Group
The group.
src : tuple of the form ('i' or 'o', name)
The source node.
tgt : tuple of the form ('i' or 'o', name)
The target node.
**kwargs : dict
The keyword arguments to add the edge.
"""
if (src, tgt) in self.edges():
return
if src not in self or tgt not in self:
raise ValueError(f"Node {src} or {tgt} not found in the graph.")
if self.in_degree(tgt) != 0:
self._mult_inconn_nodes.add(tgt)
if src[0] == 'o':
for p in self.pred[tgt]:
if p[0] == 'o':
self._bad_conns.add((src, tgt))
# sort names because sometimes the order is reversed
names = sorted([self.msgname(src), self.msgname(p)])
group._collect_error(
f"{group.msginfo}: Target '{self.msgname(tgt)}' cannot be "
f"connected to '{names[0]}' because it's already "
f"connected to '{names[1]}'.", ident=(src, tgt))
return
self.add_edge(src, tgt, **kwargs)
def node_name(self, pathname, name, io):
"""
Return the name of a node.
Parameters
----------
pathname : str
The pathname of the scoping system.
name : str
The name of the variable.
io : str
The I/O type of the variable.
Returns
-------
tuple of the form ('i' or 'o', name)
The node name.
"""
return (io[0], pathname + '.' + name if pathname else name)
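# For example (hypothetical names): node_name('sub.group1', 'x', 'input') returns
# ('i', 'sub.group1.x'), and node_name('', 'comp.y', 'output') returns ('o', 'comp.y').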
def get_node_attrs(self, pathname, name, io):
"""
Get attrs from a node, adding the node if necessary.
Parameters
----------
pathname : str
The pathname of the scoping system.
name : str
The name of the variable relative to the scoping system.
io : str
The I/O type of the variable.
Returns
-------
tuple of the form (node, NodeAttrs)
The node key and its attributes.
"""
node = self.node_name(pathname, name, io)
if node not in self:
attrs = NodeAttrs(pathname=pathname, rel_name=name)
self.add_node(node, attrs=attrs)
else:
attrs = self.nodes[node]['attrs']
return node, attrs
def get_path_prom(self, node):
"""
Get the system path and promoted name of a node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
Returns
-------
tuple of the form (str, str)
The system pathname and promoted name of the node relative to the system.
"""
meta = self.nodes[node]['attrs']
return meta.pathname, meta.rel_name
def add_continuous_var(self, model, name, meta, locmeta, io):
"""
Add a continuous variable to the graph.
Parameters
----------
model : Group
The model.
name : str
The name of the variable.
meta : dict or None
The metadata of the variable.
locmeta : dict or None
The local metadata of the variable.
io : str
The I/O type of the variable, either 'input' or 'output'.
"""
node, node_meta = self.get_node_attrs('', name, io[0])
node_meta.discrete = False
self.set_model_meta(model, node, meta, locmeta)
if node_meta.dyn_shape:
self._has_dynamic_shapes = True
if node_meta.dyn_units:
self._has_dynamic_units = True
if io == 'input' and meta['require_connection']:
self._required_conns.add(name)
node_meta.require_connection = True
def add_discrete_var(self, model, name, meta, locmeta, io):
"""
Add a discrete variable to the graph.
Parameters
----------
model : Group
The model.
name : str
The name of the variable.
meta : dict or None
The metadata of the variable.
locmeta : dict or None
The local metadata of the variable.
io : str
The I/O type of the variable, either 'input' or 'output'.
"""
node, node_meta = self.get_node_attrs('', name, io[0])
node_meta.discrete = True
self.set_model_meta(model, node, meta, locmeta)
def add_model_vars(self, model):
"""
Add variable metadata to the graph.
Parameters
----------
model : Group
The model.
"""
self.comm = model.comm
self._problem_meta = model._problem_meta
self.msginfo = model.msginfo
self._distributed_nodes = set()
for io in ['input', 'output']:
loc = model._var_abs2meta[io]
for name, meta in model._var_allprocs_abs2meta[io].items():
locmeta = loc[name] if name in loc else None
self.add_continuous_var(model, name, meta, locmeta, io)
loc = model._var_discrete[io]
for name, meta in model._var_allprocs_discrete[io].items():
locmeta = loc[name] if name in loc else None
self.add_discrete_var(model, name, meta, locmeta, io)
def get_dist_shapes(self, node=None):
"""
Get the distributed shapes of a variable.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
Returns
-------
dict or list of tuples
The full distributed shapes dict or a single entry for the given node.
"""
if self._dist_shapes is None:
if self.comm.size > 1:
existence = self._get_var_existence()
# make sure we have all of the distributed vars
dshapes = {}
# at this point, _distributed_nodes only contains local vars
for gnode in self._distributed_nodes:
io, name = gnode
mode = 'output' if io == 'o' else 'input'
if existence[mode][self.comm.rank, self._var_allprocs_abs2idx[name]]:
dshapes[gnode] = self.nodes[gnode]['attrs'].shape
all_dshapes = {}
for rank, dshp in enumerate(self.comm.allgather(dshapes)):
for n, shp in dshp.items():
if n not in all_dshapes:
all_dshapes[n] = [None] * self.comm.size
all_dshapes[n][rank] = shp
self._distributed_nodes.update(all_dshapes.keys())
self._dist_shapes = all_dshapes
else:
self._dist_shapes = {}
if node is None:
return self._dist_shapes
if node in self._dist_shapes:
for s in self._dist_shapes[node]:
if s is not None:
return self._dist_shapes[node]
# if we get here, node is dynamically shaped so we didn't know the shape the first time
# around
if node in self._distributed_nodes and self.comm.size > 1:
node_meta = self.nodes[node]['attrs']
dist_shapes = self.comm.allgather(node_meta.shape)
self._dist_shapes[node] = dist_shapes
return dist_shapes
else:
raise ValueError(f"Can't get distributed shapes for variable '{node[1]}' because it is "
"not a distributed variable in the model.")
def get_dist_sizes(self, node=None):
"""
Get the distributed sizes of a variable.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
Returns
-------
dict or list of ints
The full distributed sizes dict or a single entry for the given node.
"""
if self._dist_sizes is None:
if self.comm.size > 1:
dshapes = self.get_dist_shapes()
self._dist_sizes = dsizes = {}
for n, shapes in dshapes.items():
sizes = [shape_to_len(shape) if shape is not None else 0
for shape in shapes]
# if total size is 0, dshapes haven't been set yet
if np.sum(sizes) > 0:
dsizes[n] = sizes
else:
self._dist_sizes = {}
if node is None:
return self._dist_sizes
if node in self._dist_sizes:
return self._dist_sizes[node]
if node in self._distributed_nodes:
shapes = self.get_dist_shapes(node)
sizes = [shape_to_len(shape) if shape is not None else 0 for shape in shapes]
self._dist_sizes[node] = sizes
return sizes
else:
raise ValueError(f"Can't get distributed sizes for variable '{node[1]}' because it is "
"not a distributed variable in the model.")
def compute_global_shape(self, node):
"""
Compute the global shape of a variable.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
node.
Returns
-------
tuple of ints
The global shape of the variable.
"""
try:
return get_global_dist_shape(self.get_dist_shapes(node))
except ValueError as err:
self._collect_error(f"Can't get global shape of distributed variable '{node[1]}': "
f"{err}", ident=node)
return
def add_manual_connections(self, group):
"""
Add manual connections to the graph for the given group.
Parameters
----------
group : Group
The group.
"""
manual_connections = group._manual_connections
resolver = group._resolver
allprocs_discrete_in = group._var_allprocs_discrete['input']
allprocs_discrete_out = group._var_allprocs_discrete['output']
for prom_tgt, (prom_src, src_indices) in manual_connections.items():
src_io = resolver.get_iotype(prom_src)
if src_io is None:
guesses = get_close_matches(prom_src, list(resolver.prom_iter('output')) +
list(allprocs_discrete_out.keys()))
self._bad_conns.add((prom_src, prom_tgt))
group._collect_error(f"{group.msginfo}: Attempted to connect from '{prom_src}' to "
f"'{prom_tgt}', but '{prom_src}' doesn't exist. Perhaps you "
"meant to connect to one of the following outputs: "
f"{guesses}.")
continue
if resolver.is_prom(prom_tgt, 'input'):
tgt_io = 'input'
else:
tgt_io = resolver.get_iotype(prom_tgt)
if tgt_io == 'output':
# check that target is not also an input
self._bad_conns.add((prom_src, prom_tgt))
group._collect_error(f"{group.msginfo}: Attempted to connect from '{prom_src}' to "
f"'{prom_tgt}', but '{prom_tgt}' is an output. All "
"connections must be to an input.")
continue
if tgt_io is None:
guesses = get_close_matches(prom_tgt, list(resolver.prom_iter('input')) +
list(allprocs_discrete_in.keys()))
self._bad_conns.add((prom_src, prom_tgt))
group._collect_error(f"{group.msginfo}: Attempted to connect from '{prom_src}' to "
f"'{prom_tgt}', but '{prom_tgt}' doesn't exist. Perhaps you "
f"meant to connect to one of the following inputs: {guesses}.")
continue
out_comps = {n.rpartition('.')[0] for n in resolver.absnames(prom_src, src_io)}
in_comps = {n.rpartition('.')[0] for n in resolver.absnames(prom_tgt, tgt_io)}
if out_comps & in_comps:
self._bad_conns.add((prom_src, prom_tgt))
group._collect_error(f"{group.msginfo}: Source and target are in the same System "
f"for connection from '{prom_src}' to '{prom_tgt}'.")
continue
src_node, _ = self.get_node_attrs(group.pathname, prom_src, src_io[0])
tgt_node, _ = self.get_node_attrs(group.pathname, prom_tgt, tgt_io[0])
if src_io == 'input' and tgt_io == 'input':
self._input_input_conns.add((src_node, tgt_node))
self.check_add_edge(group, src_node, tgt_node, type='manual', src_indices=src_indices)
def gather_data(self, model):
"""
Gather graph node and edge data from all processes for the given model.
Parameters
----------
model : Group
The model.
"""
# we don't have auto-ivcs yet, so some inputs are dangling
myrank = model.comm.rank
vars_to_gather = model._vars_to_gather
nodes_to_send = {}
edges_to_send = {}
in_degree = self.in_degree
nodes = self.nodes
edges = self.edges
for abs_in in chain(model._var_allprocs_abs2meta['input'],
model._var_allprocs_discrete['input']):
if abs_in in vars_to_gather and vars_to_gather[abs_in] == myrank:
in_node = ('i', abs_in)
for inode in self.bfs_up_iter(in_node, include_self=False):
if inode[0] == 'o' or in_degree(inode) == 0:
path = nx.shortest_path(self, inode, in_node)
outstart = inode[0] == 'o'
for i, p in enumerate(path):
if p not in nodes_to_send and (not outstart or i > 0):
nodes_to_send[p] = nodes[p]['attrs'].as_dict()
if i > 0:
edge = (path[i-1], p)
edges_to_send[edge] = edges[edge]
for abs_out in chain(model._var_allprocs_abs2meta['output'],
model._var_allprocs_discrete['output']):
if abs_out in vars_to_gather and vars_to_gather[abs_out] == myrank:
# include other promoted output nodes
for u, v in nx.dfs_edges(self, ('o', abs_out)):
if u not in nodes_to_send:
nodes_to_send[u] = nodes[u]['attrs'].as_dict()
edges_to_send[u, v] = edges[u, v]
if v[0] == 'i':
continue
nodes_to_send[v] = nodes[v]['attrs'].as_dict()
# edges_to_send[u, v] = edges[u, v]
graph_info = model.comm.allgather((nodes_to_send, edges_to_send))
all_abs2meta = model._var_allprocs_abs2meta
all_discrete = model._var_allprocs_discrete
for rank_nodes, _ in graph_info:
for node, data in rank_nodes.items():
if node not in self:
data['remote'] = True
attrs = NodeAttrs()
attrs.update(data)
self.add_node(node, attrs=attrs)
if node[0] == 'i':
na2m = all_abs2meta['input']
ndisc = all_discrete['input']
else:
na2m = all_abs2meta['output']
ndisc = all_discrete['output']
if node[1] in na2m:
nodes[node]['attrs'].meta = na2m[node[1]]
elif node[1] in ndisc:
nodes[node]['attrs'].meta = ndisc[node[1]]
for _, rank_edges in graph_info:
for edge, data in rank_edges.items():
if edge in edges:
if data.get('src_indices', None) is not None:
if edges[edge].get('src_indices', None) is None:
edges[edge]['src_indices'] = data['src_indices']
else:
assert edge[0] in nodes and edge[1] in nodes
self.add_edge(edge[0], edge[1], **data)
def resolve_from_children(self, model, src_node, node, auto=False):
"""
Resolve metadata from children for the given node.
This propagates information up the tree from leaf inputs up to either the root input node
or to the root auto_ivc node.
Parameters
----------
model : Group
The model.
src_node : tuple of the form ('i' or 'o', name)
The source node.
node : tuple of the form ('i' or 'o', name)
The node.
auto : bool
Whether the source node is an auto_ivc node.
"""
if self.out_degree(node) == 0: # skip leaf nodes
return
try:
children_meta = \
[(self.nodes[child]['attrs'], self.edges[(node, child)].get('src_indices', None))
for child in self.succ[node]]
node_meta = self.nodes[node]['attrs']
if node[0] == 'i':
remote = all(m.remote for m, _ in children_meta)
else:
remote = False # don't transfer 'remote' status to connected outputs
node_meta.remote = remote
node_meta.require_connection = any(m.require_connection for m, _ in children_meta)
node_meta.discrete = discrete = self.get_discrete_from_children(model, node,
children_meta)
node_meta.distributed = \
self.get_distributed_from_children(model, node, children_meta, auto,
self.nodes[src_node]['attrs'].distributed)
ambig_units = ambig_val = None
if node[0] == 'o':
for i, tup in enumerate(children_meta):
cm, _ = tup
ambig = cm.ambiguous
if ambig:
nlist = [s for s in self.succ[node]]
if cm.ambiguous_units:
ambig_units = nlist[i]
elif cm.ambiguous_val:
ambig_val = nlist[i]
break
if not discrete:
node_meta.units_by_conn = all(m.units_by_conn for m, _ in children_meta)
if not ambig_units:
node_meta.units = self.get_units_from_children(model, node, children_meta,
node_meta.defaults)
node_meta.shape_by_conn = all(m.shape_by_conn for m, _ in children_meta)
node_meta.shape = \
self.get_shape_from_children(node, children_meta, node_meta.defaults)
if not ambig_val:
val = self.get_val_from_children(model, node, children_meta, node_meta.defaults,
auto)
if val is not None:
if node[1].startswith('_auto_ivc.'):
val = deepcopy(val)
node_meta.val = val
if node_meta._locmeta is not None:
node_meta._locmeta['val'] = val
else:
node_meta.val = val
if ambig_units:
raise ConnError(self.ambig_units_msg(ambig_units))
if ambig_val:
raise ConnError(self.ambig_values_msg(ambig_val))
except Exception as err:
if isinstance(err, ConnError):
model._collect_error(f"{model.msginfo}: {err}", tback=err.__traceback__,
ident=node)
else:
model._collect_error(f"{model.msginfo}: While resolving children of '{node[1]}': "
f"{err}", tback=err.__traceback__,
ident=node)
def ambig_units_msg(self, node, incompatible=False):
"""
Generate a message for an ambiguous or incompatible units error.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
incompatible : bool
Whether the units are incompatible.
Returns
-------
str
The error message.
"""
rows = []
self.find_ambiguous_causes(node, rows, 'units')
rows = sorted(rows, key=lambda x: x[0])
table = textwrap.indent(str(generate_table(rows, tablefmt='plain')), ' ')
msg = (f"The following inputs promoted to '{node[1]}' have different "
f"units:\n{table}")
if incompatible:
msg += "\n These units are incompatible."
else:
msg += ("\n Call model.set_input_defaults('"
f"{self.top_name(node)}', units=?) to remove the ambiguity.")
return msg
def ambig_shapes_msg(self, node, children_meta):
"""
Generate a message for an ambiguous shapes error.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
children_meta : list of tuples of the form (NodeAttrs, list of Indexers or None)
The children metadata and source indices.
Returns
-------
str
The message.
"""
node_meta = self.nodes[node]['attrs']
shapes = [m.shape_from_child(node_meta, src_indices) for m, src_indices in children_meta]
children = [n for _, n in self.succ[node]]
rows = sorted((n, s) for n, s in zip(children, shapes))
table = textwrap.indent(str(generate_table(rows, tablefmt='plain')), ' ')
return (f"The following inputs promoted to '{node[1]}' have different "
f"incompatible shapes:\n{table}")
def ambig_values_msg(self, node):
"""
Generate a message for an ambiguous values error.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
Returns
-------
str
The message.
"""
causes = []
self.find_ambiguous_causes(node, causes, 'val')
children = [n for n, _ in causes]
child_nodes = [('i', n) for n in children]
causing_meta = [self.nodes[n]['attrs'] for n in child_nodes]
units_list = [m.units for m in causing_meta]
vals = [m.val for m in causing_meta]
ulist = [u if u is not None else '' for u in units_list]
vlist = [truncate_str(v, max_len=60) for v in vals]
rows = sorted((n, u, v) for n, u, v in zip(children, ulist, vlist))
table = textwrap.indent(str(generate_table(rows, tablefmt='plain')), ' ')
return (f"The following inputs promoted to '{node[1]}' have different "
f"values, so the value of '{node[1]}' is ambiguous:\n{table}\n Call "
f"model.set_input_defaults('"
f"{self.top_name(node)}', val=?) to remove the ambiguity.")
def find_ambiguous_causes(self, node, causes, data_name):
"""
Find all of the nodes that are causing the ambiguity, starting from an ambiguous node.
Parameters
----------
node : tuple
The node to start from.
causes : list
A list of tuples of the form (node_name, data_value).
data_name : str
The name of the data to find the causes of.
"""
attr = f"ambiguous_{data_name}"
for child in self.succ[node]:
child_meta = self.nodes[child]['attrs']
if getattr(child_meta, attr):
self.find_ambiguous_causes(child, causes, data_name)
else:
causes.append((child[1], getattr(child_meta, data_name)))
def get_units_from_children(self, model, node, children_meta, defaults):
"""
Get the units from the children of the given node.
Parameters
----------
model : Group
The model.
node : tuple of the form ('i' or 'o', name)
The node.
children_meta : list of tuples of the form (NodeAttrs, list of Indexers or None)
The children metadata and source indices.
defaults : Defaults
The default metadata for the node.
Returns
-------
str or None
The units from the children of the given node.
"""
start = None
nodes = self.nodes
is_output = node[0] == 'o'
node_ambig = False
child_units_differ = False
unset = object()
start = unset
for chmeta, _ in children_meta:
if chmeta.units_by_conn and self._first_pass:
continue
if start is unset:
start = chmeta.units_from_child()
node_ambig = chmeta.ambiguous_units
continue
u = chmeta.units_from_child()
node_ambig |= chmeta.ambiguous_units
if is_output and node_ambig:
# we want the ambiguous child (input) for error reporting
for s in self.succ[node]:
if nodes[s]['attrs'].ambiguous_units:
raise ConnError(self.ambig_units_msg(s))
one_none = (start is not None and u is None) or (start is None and u is not None)
child_units_differ |= start != u
if not one_none and not is_compatible(start, u):
raise ConnError(self.ambig_units_msg(node, incompatible=True))
if start is unset:
start = None
if is_output:
return start
if defaults.units is not None:
# setting default units removes any ambiguity
nodes[node]['attrs'].ambiguous_units = False
return defaults.units
if child_units_differ:
nodes[node]['attrs'].ambiguous_units = True
else:
# if a child is ambiguous, this node is also ambiguous if the default units are not set
if node_ambig:
nodes[node]['attrs'].ambiguous_units = True
return start
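# A hedged usage note for resolving the ambiguity that get_units_from_children detects:
# the user-facing fix is Group.set_input_defaults (the promoted name and units below
# are placeholders):
#
#     model.set_input_defaults('x', units='m')
#
# Setting defaults.units clears the ambiguity here because the default is returned
# before the ambiguous_units flag is consulted.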
def get_shape_from_children(self, node, children_meta, defaults):
"""
Get the shape from the children of the given node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
children_meta : list of tuples of the form (NodeAttrs, list of Indexers or None)
The children metadata and source indices.
defaults : Defaults
The default metadata for the node.
Returns
-------
tuple or None
The shape from the children of the given node.
"""
node_meta = self.nodes[node]['attrs']
start = None
for chmeta, src_indices in children_meta:
shape = chmeta.shape_from_child(node_meta, src_indices)
val = chmeta.val_from_child(node_meta, src_indices)
if shape is None and val is not None and np.ndim(val) > 0:
if src_indices is None:
shape = np.shape(val)
if shape is not None:
if start is None:
start = shape
else:
if not array_connection_compatible(start, shape):
raise ConnError(self.ambig_shapes_msg(node, children_meta))
if defaults.val is not None and np.ndim(defaults.val) > 0:
return np.shape(defaults.val)
return start
def get_val_from_children(self, model, node, children_meta, defaults, auto):
"""
Get the value from the children of the given node.
Parameters
----------
model : Group
The model.
node : tuple of the form ('i' or 'o', name)
The node.
children_meta : list of tuples of the form (NodeAttrs, list of Indexers or None)
The children metadata and source indices.
defaults : Defaults
The default metadata for the node.
auto : bool
Whether the source node is an auto_ivc node.
Returns
-------
any or None
The value from the children of the given node.
"""
node_meta = self.nodes[node]['attrs']
start = None
unshaped_scalar = False
for i, tup in enumerate(children_meta):
ch_meta, src_indices = tup
val = ch_meta.val_from_child(node_meta, src_indices)
if ch_meta.ambiguous_val:
node_meta.ambiguous_val = True
if auto:
continue
if val is not None:
if start is None or (unshaped_scalar and ch_meta.shape is not None):
start = val
start_type = type(start)
start_units = ch_meta.units_from_child()
unshaped_scalar = ch_meta.shape is None and np.ndim(start) == 0
elif auto: # values must be the same or value of auto_ivc will be ambiguous
if ch_meta.discrete:
if start_type is not type(val):
node_meta.ambiguous_val = True
continue
if isinstance(start, np.ndarray):
if not np.all(start == val):
node_meta.ambiguous_val = True
continue
if start != val:
node_meta.ambiguous_val = True
else: # continuous
if has_val_mismatch(start_units, start, ch_meta.units, val):
node_meta.ambiguous_val = True
elif not (ch_meta.shape is None and np.ndim(val) == 0):
if not are_compatible_values(start, val, ch_meta.discrete,
src_indices=src_indices):
slist = list(self.succ[node])
raise ConnError(self.value_error(True, slist[i], node, start, val))
if defaults.val is not None:
node_meta.ambiguous_val = False
return defaults.val
if not node_meta.ambiguous_val:
return start
def get_discrete_from_children(self, group, node, children_meta):
"""
Get the discrete flag from the children of the given node.
Parameters
----------
group : Group
The group.
node : tuple of the form ('i' or 'o', name)
The node.
children_meta : list of tuples of the form (NodeAttrs, list of Indexers or None)
The children metadata and source indices.
Returns
-------
bool
The discrete flag from the children of the given node.
"""
discretes = [m.discrete for m, _ in children_meta]
dset = set(discretes)
if len(dset) == 1:
discrete = dset.pop()
if discrete:
node_meta = self.nodes[node]['attrs']
if node_meta.defaults.units is not None:
group._collect_error(f"Cannot set 'units={node_meta.defaults.units}' for "
f"discrete variable '{node[1]}'.")
if node_meta.defaults.src_shape is not None:
group._collect_error(f"Cannot set 'shape={node_meta.defaults.src_shape}' for "
f"discrete variable '{node[1]}'.")
return discrete
else:
slist = list(self.succ[node])
discs = [s for s, d in zip(slist, discretes) if d]
non_discs = [s for s, d in zip(slist, discretes) if not d]
raise ConnError(f"'{node[1]}' has discrete ({sorted(discs)}) and non-discrete "
f"({sorted(non_discs)}) children.")
def get_distributed_from_children(self, model, node, children_meta, auto, src_distributed):
"""
Get the distributed flag from the children of the given node.
Parameters
----------
model : Group
The model.
node : tuple of the form ('i' or 'o', name)
The node.
children_meta : list of tuples of the form (NodeAttrs, list of Indexers or None)
The children metadata and source indices.
auto : bool
Whether the source node is an auto_ivc node.
src_distributed : bool
Whether the source node is distributed.
Returns
-------
bool or None
The distributed flag from the children of the given node.
"""
# A parent is only designated as distributed if it has only one child and that child is
# distributed.
dist = set()
has_src_indices = False
for m, src_indices in children_meta:
dist.add(m.distributed)
if src_indices is not None:
has_src_indices = True
if auto and True in dist:
return None # error will be collected later
if len(dist) > 1:
return None # leave dist ambiguous
ret = dist.pop()
if ret is False or not src_distributed:
return False
if has_src_indices or len(children_meta) > 1:
return None # leave dist ambiguous
return ret
def get_parent_val_shape_units(self, parent, child):
"""
Get the value, shape, and units from the parent of the given child node.
Parameters
----------
parent : tuple of the form ('i' or 'o', name)
The parent node.
child : tuple of the form ('i' or 'o', name)
The child node.
Returns
-------
tuple of the form (val, shape, units)
The value, shape, and units from the parent of the given child node.
"""
parent_meta = self.nodes[parent]['attrs']
src_indices = self.edges[(parent, child)].get('src_indices', None)
val = parent_meta.val
shape = parent_meta.shape
if not (src_indices is None or shape is None):
if src_indices._src_shape is None:
src_indices.set_src_shape(shape)
shape = src_indices.indexed_src_shape
if val is not None:
val = src_indices.indexed_val(np.atleast_1d(val))
if parent_meta.ambiguous_val:
val = None
units = None if parent_meta.ambiguous_units else parent_meta.units
return val, shape, units
def check_src_to_tgt_indirect(self, model, src, tgt, src_shape, tgt_shape):
"""
Check compatibility between nodes that are not directly connected.
Parameters
----------
model : Group
The model.
src : tuple of the form ('i' or 'o', name)
The source node.
tgt : tuple of the form ('i' or 'o', name)
The target node.
src_shape : tuple
The source shape.
tgt_shape : tuple
The target shape.
"""
src_meta = self.nodes[src]['attrs']
tgt_meta = self.nodes[tgt]['attrs']
if src_shape is not None and not (src_meta.distributed and src_meta.remote):
if tgt_shape is not None:
if not (tgt_meta.distributed and tgt_meta.remote):
if not array_connection_compatible(src_shape, tgt_shape):
self._collect_error(self.shape_error(src, tgt, src_shape, tgt_shape))
return
elif not tgt_meta.ambiguous_val:
tgt_meta.shape = src_shape
src_units = src_meta.units
tgt_units = tgt_meta.units
if src_units is not None:
if tgt_units is not None:
if tgt_units != 'ambiguous' and not is_compatible(src_units, tgt_units):
self._collect_error(self.units_error(False, src, tgt, src_units, tgt_units))
elif tgt_meta.units_by_conn:
tgt_meta.units = src_units
def get_dist_offset(self, node, rank, flat):
"""
Get the distributed offset and size of axis 0 of the given node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
rank : int
The rank of the process in the communicator.
flat : bool
Whether the source is flat.
Returns
-------
tuple
The distributed offset and size of axis 0 of the given node.
"""
offset = 0
for i, dshape in enumerate(self._dist_shapes[node]):
if i == rank:
break
if dshape is not None:
if flat:
offset += shape_to_len(dshape)
else:
offset += dshape[0] if len(dshape) > 0 else 1
if dshape is None:
return offset, 0
if flat:
sz = shape_to_len(dshape)
else:
sz = dshape[0] if len(dshape) > 0 else 1
return offset, sz
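# A worked sketch of get_dist_offset (hypothetical per-rank shapes; assumes
# self._dist_shapes[node] == [(2, 3), (4, 3), None] on a 3-process comm):
#
#     graph.get_dist_offset(node, rank=1, flat=True)   # -> (6, 12)  skip rank 0's 2*3 entries
#     graph.get_dist_offset(node, rank=1, flat=False)  # -> (2, 4)   axis-0 offset and size
#     graph.get_dist_offset(node, rank=2, flat=True)   # -> (18, 0)  no local part on rank 2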
def check_dist_connection(self, model, src_node):
"""
Check a connection starting at src where src and/or a target is distributed.
Parameters
----------
model : Group
The model.
src_node : tuple of the form ('i' or 'o', name)
The source node.
"""
nodes = self.nodes
src_meta = nodes[src_node]['attrs']
if src_meta.shape is None:
return # an earlier error has occurred, so skip
src_dist = src_meta.distributed
if src_dist:
if src_meta.global_shape is None:
src_meta.global_shape = self.compute_global_shape(src_node)
src_inds_shape = src_meta.global_shape
else:
src_inds_shape = src_meta.shape
leaves = list(self.leaf_input_iter(src_node))
for tgt_node in leaves:
tgt_meta = nodes[tgt_node]['attrs']
tgt_dist = tgt_meta.distributed
src_inds_list = tgt_meta.src_inds_list
if tgt_meta.distributed:
if tgt_meta.global_shape is None:
tgt_meta.global_shape = self.compute_global_shape(tgt_node)
if src_dist:
if tgt_dist: # dist --> dist
src_shape = src_meta.shape
tgt_shape = tgt_meta.shape
if not src_inds_list:
# no src_indices, so shape of dist src must match shape of dist tgt on
# each rank, and we must specify src_indices to match src and tgt on
# each rank.
if tgt_shape is not None:
offset, sz = self.get_dist_offset(tgt_node, self.comm.rank, False)
if sz is not None:
if sz == 0:
src_indices = \
indexer(slice(0, 0), flat_src=True, src_shape=(0,))
else:
src_indices = \
indexer(slice(offset, offset + sz),
flat_src=False, src_shape=src_meta.global_shape)
path = nx.shortest_path(self, src_node, tgt_node)
self.edges[(path[-2], tgt_node)]['src_indices'] = src_indices
tgt_meta.src_inds_list = src_inds_list = [src_indices]
else: # dist --> serial
if not src_inds_list:
self._collect_error(f"Can't automatically determine src_indices for "
f"connection from distributed variable '{src_node[1]}'"
f" to serial variable '{tgt_node[1]}'.")
return
src_shape = src_meta.global_shape
tgt_shape = tgt_meta.shape
else:
src_shape = src_meta.shape
tgt_shape = tgt_meta.shape
if src_shape is None:
# an earlier error has occurred, so skip
return
if tgt_dist: # serial --> dist
# src_inds_shape = (shape_to_len(src_shape) * model.comm.size,)
if not src_inds_list:
# we can automatically add src_indices to match up this distributed target
# with the serial source
if tgt_shape is not None and src_shape == tgt_meta.global_shape:
offset, sz = self.get_dist_offset(tgt_node, model.comm.rank, False)
src_indices = \
indexer(slice(offset, offset + sz),
flat_src=False,
src_shape=src_inds_shape)
path = nx.shortest_path(self, src_node, tgt_node)
self.edges[(path[-2], tgt_node)]['src_indices'] = src_indices
tgt_meta.src_inds_list = src_inds_list = [src_indices]
else: # serial --> serial
pass
try:
for i, src_inds in enumerate(src_inds_list):
if i == 0:
src_inds.set_src_shape(src_inds_shape)
else:
src_inds.set_src_shape(src_shape)
src_shape = src_inds.indexed_src_shape
except Exception as err:
if not src_meta.dyn_shape: # if dyn_shape, error is already collected
self._collect_error(f"Error in connection between '{src_node[1]}' and "
f"'{tgt_node[1]}': {err}")
return
self.check_src_to_tgt_indirect(model, src_node, tgt_node, src_shape, tgt_shape)
def resolve_output_to_output_down(self, parent, child):
"""
Resolve metadata for a target output node based on the metadata of a parent output node.
Parameters
----------
parent : tuple of the form ('o', name)
The source output node.
child : tuple of the form ('o', name)
The child output node.
"""
child_meta = self.nodes[child]['attrs']
parent_meta = self.nodes[parent]['attrs']
if parent_meta['discrete']:
for key in _discrete_copy_meta:
setattr(child_meta, key, getattr(parent_meta, key))
else:
for key in _continuous_copy_meta:
setattr(child_meta, key, getattr(parent_meta, key))
if parent_meta.distributed:
shapes = self.get_dist_shapes()
shapes[child] = shapes[parent]
def resolve_down(self, model, parent, child, auto):
"""
Update the child node's metadata based on the parent node's metadata.
Parameters
----------
model : Group
The model.
parent : tuple
The parent node.
child : tuple
The target node.
auto : bool
Whether the source node of the connection tree is an auto_ivc node.
"""
try:
if parent[0] == 'o':
if child[0] == 'i':
return self.resolve_output_input_connection(model, parent, child)
else:
self.resolve_output_to_output_down(parent, child)
else:
self.resolve_input_to_input_down(model, parent, child, auto)
except Exception as err:
self.handle_error(False, parent, child, exc=err)
return True
def resolve_conn_tree(self, model, src_node):
"""
Resolve the connection tree rooted at src_node.
Metadata is first propagated up the tree from the absolute input nodes to the root
input node. For auto_ivc rooted trees, the propagation continues to the root auto_ivc node.
Then compatibility between nodes is checked from parent to child down the
tree, and in some cases metadata is propagated down the tree as well.
Parameters
----------
model : Group
The model.
src_node : tuple
The source node. This is always an absolute output node.
"""
if src_node in self._resolved:
return
nodes = self.nodes
auto = src_node[1].startswith('_auto_ivc.')
src_meta = nodes[src_node]['attrs']
dynamic = src_meta.dynamic
src_dist = src_meta.distributed
first_pass = self._first_pass
has_dist = False
if src_dist:
self._dist_nodes.add(src_node)
has_dist = True
dist_nodes = self._dist_nodes
abs2meta_in = model._var_allprocs_abs2meta['input']
# first, resolve inputs up from the bottom of the tree to the root input node.
for node in dfs_postorder_nodes(self, src_node):
if src_dist and first_pass and node[1] in abs2meta_in:
self._dist_nodes.update(nx.shortest_path(self, src_node, node))
if node[0] == 'i':
node_meta = nodes[node]['attrs']
if node_meta.dynamic:
# if tree contains no dynamic nodes then it's resolved after the first pass
dynamic = True
if first_pass and node[1] in abs2meta_in:
if not src_dist and node_meta.distributed:
has_dist = True
self._dist_nodes.update(nx.shortest_path(self, src_node, node))
if auto:
self._dist_nodes.add(src_node)
self.resolve_from_children(model, src_node, node, auto=auto)
# resolve auto_ivc node (these are never promoted so there is always only one output node)
if auto:
# if under MPI, check for ambiguity of value in root input node
if model.comm.size > 1:
leaves = list(self.leaf_input_iter(src_node))
if len(leaves) > 1:
for src_node, tgt_node in dfs_edges(self, src_node):
if tgt_node[0] == 'i':
# root input
inp_meta = nodes[tgt_node]['attrs']
if inp_meta.ambiguous_units:
break
if inp_meta.remote:
val = None
elif inp_meta.discrete:
val = inp_meta.val
else: # local and numerical
if isinstance(inp_meta.val, np.ndarray):
val = array_hash(inp_meta.val)
else:
val = inp_meta.val
ambig = False
tups = model.comm.gather((val, inp_meta.remote), root=0)
if model.comm.rank == 0:
start = None
for v, remote in tups:
if remote:
continue
if start is None:
start = v
else:
if start != v:
ambig = True
break
model.comm.bcast(ambig, root=0)
else:
ambig = model.comm.bcast(None, root=0)
if ambig:
inps = sorted([l for _, l in leaves])
model._collect_error(
f"The inputs {inps}, promoted to '{tgt_node[1]}' have "
f"different values, so the value of '{tgt_node[1]}' is "
"ambiguous. Call model.set_input_defaults('"
f"{self.top_name(tgt_node)}', val=?) to remove the ambiguity.")
break
self.resolve_from_children(model, src_node, src_node, auto=auto)
if src_meta.val is None:
if src_meta.shape is not None:
src_meta.val = np.ones(src_meta.shape)
# now try filling in any missing metadata going down the tree. This can happen if
# for example there are src_indices that block propagation of val and shape below a node
# where shape or val has been set by set_input_defaults. This will also fill in
# missing metadata for promoted output nodes, and it can also set shapes for
# shape_by_conn inputs. Also perform compatibility checks between parent and child nodes.
if not (dynamic and first_pass):
for edge in dfs_edges(self, src_node):
u, v = edge
if v[0] == 'i' and u in dist_nodes and v in dist_nodes:
# if the edge is along a path between a src and leaf node that are dist-dist,
# dist-serial, or serial-dist, then we need to check the compatibility of the
# connection taking distributed shapes into account.
continue
else:
self.resolve_down(model, u, v, auto)
if has_dist:
self.check_dist_connection(model, src_node)
if first_pass and not dynamic:
self._resolved.add(src_node)
def add_auto_ivc_nodes(self, model):
"""Add auto IVC nodes.
Parameters
----------
model : Group
The model.
Returns
-------
list of tuples of the form ('o', name)
The auto IVC nodes added.
"""
assert model.pathname == ''
in_degree = self.in_degree
pred = self.predecessors
# traverse up the tree from each absolute input node and find its src if it has one
dangling_inputs = set()
above_anchors = set() # input nodes directly above an anchor node
for name in chain(model._var_allprocs_abs2meta['input'],
model._var_allprocs_discrete['input']):
start_node = ('i', name)
src = None
stack = [start_node]
while stack:
node = stack.pop()
preds = list(pred(node))
outs = [p for p in preds if p[0] == 'o']
ins = [p for p in preds if p[0] == 'i']
if outs:
assert src is None and len(outs) == 1
src = outs[0]
above_anchors.update(ins)
continue
stack.extend(ins)
if src is None:
dangling_inputs.add(node)
final_dangling_inputs = []
if dangling_inputs:
# find the ancestor node of each dangling input
for d in dangling_inputs:
stack = [d]
found = False
while stack:
node = stack.pop()
if in_degree(node) == 0:
final_dangling_inputs.append(node)
found = True
break
preds = list(pred(node))
assert len(preds) == 1
p = preds[0]
if p in above_anchors:
final_dangling_inputs.append(node)
found = True
break
stack.append(p)
if not found:
final_dangling_inputs.append(d)
auto_nodes = []
for i, n in enumerate(sorted(final_dangling_inputs, key=lambda x: x[1])):
auto_node, _ = self.get_node_attrs('', f'_auto_ivc.v{i}', 'o')
self.add_edge(auto_node, n, type='manual')
auto_nodes.append(auto_node)
return auto_nodes
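# A sketch of what add_auto_ivc_nodes adds (the '_auto_ivc.v<N>' naming comes from the
# loop above; the dangling input names are hypothetical):
#
#     ('o', '_auto_ivc.v0') --'manual' edge--> ('i', 'sub.comp.x')
#     ('o', '_auto_ivc.v1') --'manual' edge--> ('i', 'x')
#
# Dangling inputs are sorted by name before numbering, so the auto_ivc output names are
# deterministic from run to run.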
@collect_errors
def final_check(self, model):
"""
Check for any errors after all updates are complete.
Parameters
----------
model : Group
The top level group.
"""
nodes = self.nodes
for abs_out in self._var_allprocs_abs2meta['output']:
node = ('o', abs_out)
auto = node[1].startswith('_auto_ivc.')
for u, v in dfs_edges(self, node):
vmeta = nodes[v]['attrs']
if auto and vmeta.distributed:
for _, t in self._bad_conns:
if t[1] == v[1]:
break
else:
iroot = self.input_root(v)
if iroot != v:
promas = f", promoted as '{iroot[1]}',"
else:
promas = ""
self._collect_error(f"Distributed input '{v[1]}'{promas} is not "
"connected. Declare an IndepVarComp and connect it "
"to this input to eliminate this error.",
ident=(u, v))
continue
if u[0] == 'o' and v[0] == 'i': # an output to input connection
uunits = nodes[u]['attrs'].units
vunits = vmeta.units
if uunits is None or vunits is None:
uunitless = _is_unitless(uunits)
vunitless = _is_unitless(vunits)
if uunitless and not vunitless:
issue_warning(f"{model.msginfo}: Input '{v[1]}' with units of "
f"'{vunits}' is connected to output '{u[1]}' "
f"which has no units.", category=UnitsWarning)
elif not uunitless and vunitless:
if not nodes[v]['attrs'].ambiguous_units:
issue_warning(f"{model.msginfo}: Output '{u[1]}' with units of "
f"'{uunits}' is connected to input '{v[1]}' "
f"which has no units.", category=UnitsWarning)
desvars = model.get_design_vars()
for req in self._required_conns:
node = ('i', req)
root = self.get_root(node)
if root[1].startswith('_auto_ivc.'):
iroot = self.input_root(node)
if iroot[1] in desvars: # design vars are 'connected'
continue
if iroot == node:
promstr = ''
else:
promstr = f", promoted as '{iroot[1]}',"
self._collect_error(f"{self.msginfo}: Input '{req}'{promstr} requires a "
f"connection but is not connected.")
def add_implicit_connections(self, model, implicit_conn_vars):
# implicit connections are added after all promotions are added, so any implicitly connected
# nodes are guaranteed to already exist in the graph.
"""Add implicit connections.
Parameters
----------
model : Group
The top level group.
implicit_conn_vars : list of tuples of the form (pathname, relname)
The implicitly connected variables.
"""
for pathname, relname in implicit_conn_vars:
onode, _ = self.get_node_attrs(pathname, relname, 'o')
inode, _ = self.get_node_attrs(pathname, relname, 'i')
self.check_add_edge(model, onode, inode, type='implicit')
def update_src_inds_lists(self, model):
# propagate src_indices down the tree, but don't update shapes because we don't
# know all of the shapes at the root and leaves of the tree yet.
# Also, determine the list of dangling promoted inputs for later processing.
"""Update source indices lists.
Parameters
----------
model : Group
The top level group.
"""
edges = self.edges
nodes = self.nodes
self._dangling_prom_inputs = dangling = set()
for node in self.nodes():
if self.in_degree(node) == 0:
if node[0] == 'o':
for u, v in dfs_edges(self, node):
if v[0] == 'i':
edge_meta = edges[u, v]
src_inds = edge_meta.get('src_indices', None)
src_inds_list = nodes[u]['attrs'].src_inds_list
if src_inds is not None:
src_inds_list = src_inds_list.copy()
src_inds_list.append(src_inds)
nodes[v]['attrs'].src_inds_list = src_inds_list
else:
dangling.add(node)
def create_all_conns_dict(self, model):
"""
Create a dict of global connections 'owned' by a group, keyed by the group's pathname.
This should only be called on the top level group.
Parameters
----------
model : Group
The top level group.
Returns
-------
dict
A dict of connections 'owned' by a group, of the form
{group_pathname: {abs_in: abs_out}}.
"""
assert model.pathname == ''
conns = {}
abs_ins = model._var_allprocs_abs2meta['input']
discrete_ins = model._var_allprocs_discrete['input']
global_conns = {}
for abs_out in model._resolver.abs_iter('output'):
node = ('o', abs_out)
for _, abs_in in self.leaf_input_iter(node):
assert abs_in in abs_ins or abs_in in discrete_ins
if abs_in in global_conns:
global_conns[abs_in].append(abs_out)
else:
global_conns[abs_in] = [abs_out]
common = common_subpath((abs_out, abs_in))
if common not in conns:
conns[common] = {}
conns[common][abs_in] = abs_out
multiple_conns = []
for abs_in, abs_outs in global_conns.items():
if len(abs_outs) > 1:
multiple_conns.append((abs_in, abs_outs))
if multiple_conns:
msg = []
for abs_in, abs_outs in multiple_conns:
msg.append(f"'{abs_in}' from {abs_outs}")
self._collect_error(f"{self.msginfo}: The following inputs have multiple connections:"
f" {', '.join(msg)}.")
return conns
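# A small sketch of the dict returned by create_all_conns_dict (hypothetical names),
# where each connection is keyed under the deepest group containing both ends, as
# determined by common_subpath:
#
#     {
#         '':        {'sub2.comp.x': 'ivc.y'},         # ends share no deeper group
#         'sub.par': {'sub.par.c2.x': 'sub.par.c1.y'}, # both ends live under 'sub.par'
#     }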
def get_root(self, node):
"""Get root.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The starting node.
Returns
-------
tuple of the form ('i' or 'o', name) or None
The root node.
"""
in_degree = self.in_degree
for n in self.bfs_up_iter(node):
if n[0] == 'o' and in_degree(n) == 0:
return n
# if we get here, node is an input promoted above the attachment point of its source output
# or an absolute input
if in_degree(node) == 0 and self.out_degree(node) == 0:
return node # unconnected node
for n in self.bfs_down_iter(node, include_self=False):
# look for output predecessors
for p in self.predecessors(n):
if p[0] == 'o':
return self.get_root(p)
def get_tree_iter(self, node):
"""
For any node in the graph, yield all of the nodes in the tree that the node belongs to.
The nodes are yielded in depth-first preorder.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node to get the tree for.
Yields
------
tuple of the form (node, depth)
The node and its depth.
"""
root = self.get_root(node)
if root is None:
if node[0] == 'i':
# may have a manual connection before a late promotion, so
# traverse down the tree to the end, then back to the root from there
for n in self.leaf_input_iter(node):
root = self.get_root(n)
break
if root is None:
raise RuntimeError(f"Can't find root of connection tree from node '{node}'.")
dq = deque([(root, 0)])
while dq:
node, depth = dq.popleft()
yield (node, depth)
next_depth = depth + 1
for child in self.successors(node):
dq.appendleft((child, next_depth))
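# A minimal traversal sketch for get_tree_iter (hypothetical tree: output 'c.y' feeds
# promoted input 'x', which feeds absolute inputs 'c1.x' and 'c2.x'); sibling order is
# not guaranteed:
#
#     for node, depth in graph.get_tree_iter(('i', 'c1.x')):
#         print(' ' * depth, node)
#     # ('o', 'c.y')    depth 0
#     # ('i', 'x')      depth 1
#     # ('i', 'c1.x')   depth 2
#     # ('i', 'c2.x')   depth 2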
def absnames(self, node):
"""
Get the absolute names associated with a node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node to get the absolute names for.
Returns
-------
list of str
The absolute names associated with the node.
"""
if node[0] == 'i':
return [n for _, n in self.leaf_input_iter(node)]
else:
return [self.get_root(node)[1]]
def get_subarray(self, arr, indices_list):
"""
Apply a sequence of indexing operations to the input array.
Parameters
----------
arr : numpy.ndarray
The initial array.
indices_list : list
A list of indexing objects (e.g., slices, integers, arrays for advanced indexing).
Returns
-------
subarray : numpy.ndarray
The result after applying all indexing operations (may be a view or copy).
"""
current = np.atleast_1d(arr)
if indices_list is not None:
for idx in indices_list:
current = idx.indexed_val(current)
return current
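# The chained indexing that get_subarray performs, expressed in plain numpy terms
# (a sketch, not the Indexer API): each entry in indices_list narrows the previous
# result, like stacked bracket operations.
#
#     import numpy as np
#     arr = np.arange(12).reshape(3, 4)
#     sub = arr[0:2][:, [1, 3]]   # analogous to two stacked Indexers
#     # sub.shape == (2, 2)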
def set_subarray(self, arr, indices_list, val, node):
"""
Set val into the positions of the original array selected by indices_list.
This function handles both views and copies by propagating changes back through the chain.
Parameters
----------
arr : numpy.ndarray
The original array to modify.
indices_list : list
A list of indexing objects (e.g., slices, integers, arrays for advanced indexing).
val : numpy.ndarray or compatible
The val to set (must match the shape of the final subarray).
node : tuple
The node to set the val for.
Raises
------
ValueError
If the shape of val does not match the final subarray shape.
"""
msg = ''
try:
if indices_list:
chain = [arr]
for idx in indices_list:
chain.append(idx.indexed_val(chain[-1]))
if np.shape(val) != () and np.squeeze(val).shape != np.squeeze(chain[-1]).shape:
msg = (f"Value shape {np.squeeze(val).shape} does not match shape "
f"{np.squeeze(chain[-1]).shape} of the destination")
else:
try:
arr[:] = val
except ValueError:
arr[:] = val.reshape(arr.shape)
return
except Exception as err:
msg = str(err)
if msg:
raise ValueError(f"Failed to set value of '{node[1]}': {msg}.")
last = chain[-1]
if (isinstance(last, np.ndarray) and last.ndim == 0) or np.isscalar(last):
if isinstance(val, np.ndarray) and val.size == 1:
val = val[0]
chain[-1] = val
else:
last[:] = val
for i in range(len(chain) - 2, -1, -1):
sub = chain[i + 1]
prev = chain[i]
idx = indices_list[i]
if sub.base is not prev:
idx.indexed_val_set(prev, sub)
def get_src_index_array(self, abs_in):
"""
Get the source index array for a given absolute input name.
Parameters
----------
abs_in : str
The absolute input name.
Returns
-------
numpy.ndarray or None
The source index array or None if the input is not found.
"""
node = ('i', abs_in)
if node not in self:
raise ValueError(f"Input '{abs_in}' not found.")
src_inds_list = self.nodes[node]['attrs'].src_inds_list
if not src_inds_list:
return None
elif len(src_inds_list) == 1:
return src_inds_list[0].shaped_array()
else:
root = self.get_root(node)
root_meta = self.nodes[root]['attrs']
if root_meta.distributed:
root_shape = root_meta.global_shape
else:
root_shape = root_meta.shape
arr = np.arange(shape_to_len(root_shape)).reshape(root_shape)
for inds in src_inds_list:
arr = inds.indexed_val(arr)
return arr
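# Sketch of how get_src_index_array collapses stacked src_indices into one index array:
# indexing an arange over the root's size with each Indexer in turn leaves the flat
# positions within the root value that this input sees. In plain numpy terms
# (hypothetical shapes and indices):
#
#     import numpy as np
#     root_ids = np.arange(12).reshape(3, 4)   # stand-in for the root variable
#     flat_ids = root_ids[0:2][:, [1, 3]]      # two promotion levels of src_indices
#     # flat_ids -> array([[1, 3], [5, 7]])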
def convert_get(self, node, val, src_units, tgt_units, src_inds_list=(), units=None,
indices=None, get_remote=False):
"""
Convert the value for a given node based on units and indices.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node to convert the value for.
val : numpy.ndarray
The value to convert.
src_units : str
The source units.
tgt_units : str
The target units.
src_inds_list : list of Indexer
The source indices list.
units : str or None
The units to convert to.
indices : Indexer or None
The indices to convert.
get_remote : bool
Whether to get the remote value.
Returns
-------
numpy.ndarray
The converted value.
"""
node_meta = self.nodes[node]['attrs']
if not node_meta.discrete:
val = np.asarray(val)
if src_inds_list:
val = self.get_subarray(val, src_inds_list)
if get_remote and not src_inds_list:
val = val.reshape(node_meta.global_shape)
else:
val = val.reshape(node_meta.shape)
if indices:
val = self.get_subarray(val, [indices])
if units is None:
units = tgt_units
if units is not None:
if src_units is None:
raise TypeError(f"Can't express value with units of '{src_units}' in units of "
f"'{units}'.")
elif src_units != units:
try:
scale, offset = unit_conversion(src_units, units)
except Exception:
raise TypeError(f"Can't express value with units of '{src_units}' in units of "
f"'{units}'.")
return (val + offset) * scale
return val
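# A small numeric sketch of the units step in convert_get, using the unit_conversion
# function imported at the top of this module (the specific units are illustrative):
#
#     scale, offset = unit_conversion('m', 'cm')   # scale == 100.0, offset == 0.0
#     (2.0 + offset) * scale                       # -> 200.0, i.e. 2 m expressed in cm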
def convert_set(self, val, src_units, tgt_units, src_inds_list=(), units=None, indices=None):
"""
Convert the value for a given node based on units and indices.
Parameters
----------
val : numpy.ndarray
The value to convert.
src_units : str
The source units.
tgt_units : str
The target units.
src_inds_list : list of Indexer
The source indices list.
units : str or None
The units to convert to.
indices : Indexer or None
The indices to convert.
Returns
-------
numpy.ndarray
The converted value.
"""
if indices:
src_inds_list = list(src_inds_list) + [indices]
val = self.get_subarray(val, src_inds_list)
if units is None:
units = tgt_units
if units is not None:
if src_units is None:
raise TypeError(f"Can't express value with units of '{src_units}' in units of "
f"'{units}'.")
elif src_units != units:
try:
scale, offset = unit_conversion(units, src_units)
except Exception:
raise TypeError(f"Can't express value with units of '{src_units}' in units of "
f"'{units}'.")
return (val + offset) * scale
return val
def setup_global_connections(self, model):
"""
Compute dict of all connections between inputs and outputs.
Parameters
----------
model : Group
The top level group.
"""
from openmdao.core.group import Group
# add nodes for all absolute inputs and connected absolute outputs
self.add_model_vars(model)
systems = list(model.system_iter(include_self=True, recurse=True))
groups = [s for s in systems if isinstance(s, Group)]
for g in groups:
self.add_manual_connections(g)
for g in groups:
self.add_group_input_defaults(g)
self.add_implicit_connections(model, model._get_implicit_connections())
if model.comm.size > 1:
self.gather_data(model)
# self.add_implicit_connections(model, model._get_implicit_connections())
# check for cycles
if not nx.is_directed_acyclic_graph(self):
cycle_edges = nx.find_cycle(self, orientation='original')
errmsg = '\n'.join([f' {edge[0]} ---> {edge[1]}'
for edge in cycle_edges])
self._collect_error('Cycle detected in input-to-input connections. '
f'This is not allowed.\n{errmsg}')
self.add_auto_ivc_nodes(model)
self.update_src_inds_lists(model)
model._setup_auto_ivcs()
self._var_allprocs_abs2meta = model._var_allprocs_abs2meta
self._var_allprocs_discrete = model._var_allprocs_discrete
self._var_allprocs_abs2idx = model._var_allprocs_abs2idx
self._var_abs2meta = model._var_abs2meta
# update global shapes for distributed vars, now that data structures include
# auto_ivcs
for dnode in self._distributed_nodes:
node_meta = self.nodes[dnode]['attrs']
if not node_meta.dyn_shape:
node_meta.global_shape = self.compute_global_shape(dnode)
self.update_all_node_meta(model)
self.transform_input_input_connections(model)
conn_dict = self.create_all_conns_dict(model)
global_conn_dict = {'': {}}
root_dict = global_conn_dict['']
for path, conn_data in conn_dict.items():
for name in all_ancestors(path):
if name in global_conn_dict:
global_conn_dict[name].update(conn_data)
else:
global_conn_dict[name] = conn_data.copy()
# don't forget the root path!
root_dict.update(conn_data)
for system in systems:
if isinstance(system, Group):
system._conn_abs_in2out = conn_dict.get(system.pathname, {})
system._conn_global_abs_in2out = global_conn_dict.get(system.pathname, {})
system._resolver._conns = model._conn_global_abs_in2out
def create_node_label(self, node):
"""Create the label for a displayed node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node.
Returns
-------
str
Returned html for the node label.
"""
def get_table_row(name, meta, mods=(), max_width=None, show_always=False):
"""
Get the html for a table row.
Parameters
----------
name : str
The name of the attribute.
meta : NodeAttrs
The metadata for the node.
mods : list of str
Modifiers for the html.
max_width : int or None
The maximum width of the content.
show_always : bool
Whether to show the content always. If True, the content will be shown regardless of
whether it is None, False, or an empty list.
Returns
-------
str
The html for the table row.
"""
if '.' in name:
parent, _, child = name.rpartition('.')
meta = getattr(meta, parent)
content = getattr(meta, child)
elif isinstance(meta, dict):
content = meta[name]
else:
content = getattr(meta, name)
ambig = getattr(meta, f"ambiguous_{name}")
if ambig:
content = '?'
if (content is None or content is False or (isinstance(content, list) and not content))\
and not show_always:
return ''
if max_width is not None:
content = truncate_str(content, max_width)
if mods:
starts = []
ends = []
for mod in mods:
starts.append(f"<{mod}>")
ends.append(f"</{mod}>")
content = ''.join(starts) + content + ''.join(ends)
line = f'<tr><td align="left"><b>{name}:</b></td><td fixedsize="false" align="left">'\
f'{content}</td></tr>'
return line
name = node[1]
meta = self.nodes[node]['attrs']
rows = []
rows.append(get_table_row('units', meta, ('i',)))
rows.append(get_table_row('shape', meta))
rows.append(get_table_row('val', meta, max_width=max(30, int(len(name) * 1.2))))
rows.append(get_table_row('defaults.val', meta, max_width=max(30, int(len(name) * 1.2))))
rows.append(get_table_row('defaults.src_shape', meta))
rows.append(get_table_row('defaults.units', meta))
rows.append(get_table_row('src_inds_list', meta))
rows.append(get_table_row('shape_by_conn', meta))
rows.append(get_table_row('copy_shape', meta))
# dot doesn't like node labels containing functions, so just show yes if set
rows.append(get_table_row('compute_shape',
{'compute_shape': 'yes' if meta.compute_shape else None}))
rows.append(get_table_row('compute_units',
{'compute_units': 'yes' if meta.compute_units else None}))
rows.append(get_table_row('units_by_conn', meta))
rows.append(get_table_row('copy_units', meta))
rows.append(get_table_row('discrete', meta))
rows.append(get_table_row('distributed', meta))
rows.append(get_table_row('remote', meta))
rows = [r for r in rows if r]
if rows:
combined = '\n'.join(rows)
else:
combined = ''
label = (f'<table border="0" cellborder="0" cellspacing="0" cellpadding="0">'
f'<tr><td align="left" colspan="2"><b>{name}</b></td></tr>'
f'{combined}</table>')
return label
def drawable_node_iter(self, pathname=''):
"""
Yield nodes usable in a pydot graph.
Parameters
----------
pathname : str
The pathname of the group to draw.
Yields
------
tuple of the form (node, data)
The node and its metadata.
"""
if pathname:
pathname = pathname + '.'
for node, data in self.nodes(data=True):
meta = data['attrs']
if pathname and not self.startswith(pathname, node):
continue
newdata = {}
if node[0] == 'i':
newdata['fillcolor'] = GRAPH_COLORS['input']
else:
newdata['fillcolor'] = GRAPH_COLORS['output']
if meta.ambiguous:
newdata['color'] = GRAPH_COLORS['ambiguous']
newdata['penwidth'] = '4' # Thick border
newdata['html_label'] = self.create_node_label(node)
newdata['tooltip'] = (meta.pathname, meta.rel_name)
newdata['style'] = 'filled,rounded'
newdata['shape'] = 'box' # Use box shape with rounded corners
newdata['pathname'] = meta.pathname
newdata['rel_name'] = meta.rel_name
yield node, newdata
def drawable_edge_iter(self, pathname='', show_cross_boundary=True, max_width=50):
"""
Yield edges usable in a pydot graph.
Parameters
----------
pathname : str, optional
The pathname of the group to draw.
show_cross_boundary : bool, optional
Whether to show cross boundary connections.
max_width : int or None
The maximum allowable width of an edge label.
Yields
------
tuple of the form (u, v, data)
The edge and its metadata.
"""
if pathname:
pathname = pathname + '.'
# None means the edge is a promotion
type_map = {None: 'dashed', 'manual': None, 'implicit': 'dotted'}
for u, v, data in self.edges(data=True):
style = type_map[data.get('type')]
if pathname:
u_internal = self.startswith(pathname, u)
if not u_internal and not show_cross_boundary:
continue
v_internal = self.startswith(pathname, v)
if not v_internal and not show_cross_boundary:
continue
if not (u_internal or v_internal):
continue
newdata = {}
if style:
newdata['style'] = style
if u[0] == 'i' and v[0] == 'i': # show promotion arrows in the right direction
edge = (v, u)
newdata['dir'] = 'back' # show arrows going backwards w/o messing up tree layout
else:
edge = (u, v)
src_indices = data.get('src_indices')
if src_indices is None:
newdata['tooltip'] = f"{edge[0]} -> {edge[1]}"
else:
src_indices = truncate_str(src_indices, max_len=max_width)
newdata['label'] = src_indices
newdata['tooltip'] = f"{edge[0][1]} -> {edge[1][1]}: src_indices: {src_indices}"
yield u, v, newdata
def get_drawable_graph(self, pathname='', varname=None, show_cross_boundary=True):
"""
Get a drawable version of the connection graph.
The connection graph is a collection of connection trees. A connection tree has an absolute
output at the root and absolute inputs at the leaves, with any intermediate nodes being
promoted inputs and outputs.
Parameters
----------
pathname : str
The pathname of the system to display.
varname : str or None
Display the connection tree associated with this variable. If None, display the entire
connection graph.
show_cross_boundary : bool
Whether to show cross boundary connections.
Returns
-------
nx.DiGraph
The drawable graph.
"""
G = nx.DiGraph()
G.graph['orientation'] = 'TB'
# Don't scale graph to fit - let it render at natural size for scrolling
G.graph['ratio'] = 'auto'
G.graph['margin'] = '0'
# Set consistent node spacing to prevent nodes from being oversized in small graphs
G.graph['nodesep'] = '0.5'
G.graph['ranksep'] = '0.75'
# Prevent graph from expanding to fill arbitrary width
G.graph['size'] = '20,100!' # Max 20 inches wide, 100 height, don't expand
G.graph['pack'] = 'false' # Don't pack components
G.graph['center'] = 'false' # Don't center the graph
if pathname:
# special handling for cross boundary connections
edges = list(self.drawable_edge_iter(pathname, show_cross_boundary))
draw_nodes = list(self.drawable_node_iter(pathname))
if show_cross_boundary:
visible_nodes = {n for n, _, _ in edges}
visible_nodes.update({n for _, n, _ in edges})
outside_nodes = visible_nodes.difference([n for n, _ in draw_nodes])
nodes = self.nodes
for node in outside_nodes:
node_meta = nodes[node]['attrs']
meta = {
'html_label': self.create_node_label(node),
'tooltip': str((node_meta['pathname'], node_meta['rel_name'])),
'style': 'filled,rounded',
'shape': 'box',
'pathname': node_meta['pathname'],
'rel_name': node_meta['rel_name'],
'fillcolor': GRAPH_COLORS['boundary'],
}
draw_nodes.append((node, meta))
for node, data in draw_nodes:
G.add_node(node, **data)
for u, v, data in edges:
G.add_edge(u, v, **data)
else:
for node, data in self.drawable_node_iter(pathname):
G.add_node(node, **data)
for u, v, data in self.drawable_edge_iter(pathname):
G.add_edge(u, v, **data)
if varname:
varnode = self.find_node(pathname, varname)
G.nodes[varnode]['color'] = GRAPH_COLORS['highlight']
G.nodes[varnode]['penwidth'] = '4' # Thick border
tree = nx.node_connected_component(G.to_undirected(as_view=True), varnode)
G = nx.subgraph(G, tree)
return G
def get_dot(self, pathname='', varname=None, show_cross_boundary=True):
"""
Get DOT representation of the connection graph.
Parameters
----------
pathname : str
The pathname of the system to display.
varname : str or None
Display the connection tree associated with this variable. If None, display the entire
connection graph.
show_cross_boundary : bool
Whether to show cross boundary connections.
Returns
-------
str
The DOT representation of the connection graph.
"""
return networkx_to_dot(self.get_drawable_graph(pathname, varname, show_cross_boundary))
def display(self, varname=None, pathname='', show_cross_boundary=True, outfile='graph.html'):
"""
Display the connection graph.
Parameters
----------
varname : str or None
Display the connection tree associated with this variable. If None, display the entire
connection graph.
pathname : str
The pathname of the system to display.
show_cross_boundary : bool
Whether to show cross boundary connections.
outfile : str
The name of the file to write the graph to.
"""
if outfile is None:
outfile = 'graph.html'
dot_to_html(self.get_dot(pathname, varname, show_cross_boundary),
outfile=outfile)
def print_tree(self, name):
"""
Print the tree of a given variable.
Parameters
----------
name : str
The name of the variable to print the tree of.
"""
if name in self:
node = name
else:
node = ('i', name)
if node not in self:
node = ('o', name)
if node not in self:
raise ValueError(f"Variable '{name}' not found in the graph.")
nodes = self.nodes
for node, depth in self.get_tree_iter(node):
indent = ' ' * depth
meta = nodes[node]['attrs']
mnames = ['units', 'discrete', 'shape']
dismeta = {k: meta[k] for k in mnames if k in meta and meta[k] is not None}
print(f"{indent}{node[1]} {dismeta}")
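# A hedged usage sketch for print_tree (hypothetical promoted name 'x'); each line is
# indented by tree depth and shows whichever of units/discrete/shape are set:
#
#     graph.print_tree('x')
#     # _auto_ivc.v0 {'units': 'm', 'shape': (3,)}
#     #  x {'units': 'm', 'shape': (3,)}
#     #   sub.comp.x {'units': 'm', 'shape': (3,)}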
def serve(self, port=None, open_browser=True):
"""
Serve the connection graph web UI.
Parameters
----------
port : int or None
The port to serve the graph on. If None, an unused port will be found.
open_browser : bool
Whether to open the browser to the graph.
"""
from openmdao.visualization.conn_graph_ui import ConnGraphHandler
import socket
def find_unused_port():
"""Find an unused port by letting the OS assign an ephemeral one."""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
return s.getsockname()[1]
if port is None:
port = find_unused_port()
def handler(*args, **kwargs):
"""
Return a ConnGraphHandler instance.
Parameters
----------
*args : list
Positional arguments passed to the base handler.
**kwargs : dict
Keyword arguments passed to the base handler.
Returns
-------
ConnGraphHandler
A ConnGraphHandler instance.
"""
return ConnGraphHandler(self, *args, **kwargs)
print(f"🌐 Starting Connection Graph UI on port {port}")
print(f"📱 Open your browser to: http://localhost:{port}")
if open_browser:
def open_browser():
time.sleep(1)
webbrowser.open(f'http://localhost:{port}')
threading.Thread(target=open_browser, daemon=True).start()
try:
with HTTPServer(("", port), handler) as httpd:
print(f"✅ Server running on http://localhost:{port}")
print("Press Ctrl+C to stop")
httpd.serve_forever()
except KeyboardInterrupt:
print("\n🛑 Server stopped")
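# A minimal usage sketch for serve (the port number is illustrative):
#
#     graph.serve(port=8080)           # serve the UI and open a browser tab
#     graph.serve(open_browser=False)  # pick a free port and just print the URL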
def copy_var_shape(self, from_node, to_node):
"""
Copy shape info from from_node's metadata to to_node's metadata in the graph.
Parameters
----------
from_node : tuple
Tuple containing the IO and name of the variable to copy shape info from.
to_node : tuple
Tuple containing the IO and name of the variable to copy shape info to.
Returns
-------
tuple or None
If the shape of the variable is known, return the shape.
Otherwise, return None.
"""
from_conn_meta = self.nodes[from_node]['attrs']
if from_conn_meta.shape is None:
return
to_conn_meta = self.nodes[to_node]['attrs']
from_io, from_name = from_node
_, to_name = to_node
# is this a connection internal to a component?
internal = to_name.rpartition('.')[0] == from_name.rpartition('.')[0]
if internal: # if internal to a component, src_indices isn't used
fwd = from_io == 'i'
src_inds_list = None
else:
fwd = from_io == 'o'
if fwd:
src_inds_list = to_conn_meta.src_inds_list
else: # rev
src_inds_list = from_conn_meta.src_inds_list
is_full_slice = \
src_inds_list and len(src_inds_list) == 1 and src_inds_list[0].is_full_slice()
if self.comm.size > 1:
dist_from = from_conn_meta.distributed
dist_to = to_conn_meta.distributed
else:
dist_from = dist_to = False
if fwd:
dist_src = dist_from
dist_tgt = dist_to
else:
dist_src = dist_to
dist_tgt = dist_from
fname = _shape_func_map[(dist_src, dist_tgt, fwd)]
return getattr(self, fname)(from_node, to_node, src_inds_list, is_full_slice)
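# Dispatch sketch for copy_var_shape: the (dist_src, dist_tgt, fwd) triple selects one
# of the *_shape helpers below via _shape_func_map, e.g.
#
#     _shape_func_map[(False, True, True)]   # -> 'serial2dist_fwd_shape'
#     _shape_func_map[(True, False, False)]  # -> 'dist2serial_rev_shape'
#
# so copy_var_shape itself stays agnostic of the serial/distributed details.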
def _get_var_existence(self):
"""
Get the existence of all continuous variables across all processes.
Returns
-------
dict of the form {'input': np.ndarray, 'output': np.ndarray}
A dictionary of variable existence from the model.
"""
if self._var_existence is None:
all_abs2meta = self._var_allprocs_abs2meta
self._var_existence = {
'input': np.zeros((self.comm.size, len(all_abs2meta['input'])), dtype=bool),
'output': np.zeros((self.comm.size, len(all_abs2meta['output'])), dtype=bool),
}
iproc = self.comm.rank
for io, existence in self._var_existence.items():
abs2meta = self._var_abs2meta[io]
if self.comm.size > 1:
for i, name in enumerate(all_abs2meta[io]):
if name in abs2meta:
existence[iproc, i] = True
row = existence[iproc, :].copy()
self.comm.Allgather(row, existence)
else:
existence[:] = True
return self._var_existence
def serial2serial_fwd_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a serial to serial connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
shp = self.nodes[from_node]['attrs'].shape
if src_inds_list:
for inds in src_inds_list:
inds.set_src_shape(shp)
shp = inds.indexed_src_shape
self.nodes[to_node]['attrs'].shape = shp
return shp
def serial2serial_rev_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a serial to serial reverse connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
if not src_inds_list or is_full_slice:
shp = self.nodes[from_node]['attrs'].shape
else:
self._collect_error(f"Input '{from_node[1]}' has src_indices so the shape "
f"of connected output '{to_node[1]}' cannot be "
"determined.")
return
self.nodes[to_node]['attrs'].shape = shp
return shp
def serial2dist_fwd_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a serial to distributed forward connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
from_shape = self.nodes[from_node]['attrs'].shape
existence = self._get_var_existence()
exist_outs = existence['output'][:, self._var_allprocs_abs2idx[from_node[1]]]
exist_ins = existence['input'][:, self._var_allprocs_abs2idx[to_node[1]]]
to_meta = self.nodes[to_node]['attrs']
num_exist = np.count_nonzero(exist_outs)
if not src_inds_list:
shape = from_shape
else:
if len(from_shape) <= 1:
global_src_shape = (num_exist * shape_to_len(from_shape),)
else:
shapelst = list(from_shape)
# stack the shapes across procs
first_dim = shapelst[0] * num_exist
shapelst[0] = first_dim
global_src_shape = tuple(shapelst)
shp = global_src_shape
for inds in src_inds_list:
inds.set_src_shape(shp)
shp = inds.indexed_src_shape
shape = shp
size = shape_to_len(shape)
sizes = np.zeros(self.comm.size, dtype=int)
sizes[exist_ins] = size
dist_shapes = self.get_dist_shapes()
dist_sizes = self.get_dist_sizes()
dist_sizes[to_node] = sizes
dist_shapes[to_node] = [shape if exist_ins[i] else None for i in range(self.comm.size)]
to_meta.shape = shape
nexist = len([i for i in exist_ins if i])
if len(shape) > 1:
firstdim = shape[0] * nexist
gshape = (firstdim,) + shape[1:]
else:
gshape = (size * nexist,)
to_meta.global_shape = gshape
return shape
def serial2dist_rev_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a serial to distributed reverse connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
# serial_out <-- dist_in
dshapes = self.get_dist_shapes(from_node)
dshapes = [s for s in dshapes if s is not None and shape_to_len(s) > 0]
if len(dshapes) >= 1:
shape0 = dshapes[0]
for ds in dshapes:
if ds != shape0:
to_io = 'input' if to_node[0] == 'i' else 'output'
from_io = 'input' if from_node[0] == 'i' else 'output'
dshapes = [_strip_np(s) for s in dshapes]
self._collect_error(
f"{self.msginfo}: dynamic sizing of non-distributed {to_io} '{to_node[1]}' "
f"from distributed {from_io} '{from_node[1]}' is not supported because not "
f"all '{from_node[1]}' ranks are the same shape "
f"(shapes={dshapes}).", ident=(from_node, to_node))
return
self.nodes[to_node]['attrs'].shape = shape0
return shape0
def dist2serial_fwd_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a distributed to serial forward connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
if src_inds_list:
shp = self.compute_global_shape(from_node)
for src_indices in src_inds_list:
src_indices.set_src_shape(shp)
shp = src_indices.indexed_src_shape
self.nodes[to_node]['attrs'].shape = shp
return shp
else:
# We don't allow this case because
# serial variables must have the same value on all procs and the only way
# this is possible is if the src_indices on each proc are identical, but that's not
# possible if we assume 'always local' transfer (see POEM 46).
self._collect_error(
f"{self.msginfo}: dynamic sizing of non-distributed input '{to_node[1]}' "
f"from distributed output '{from_node[1]}' without src_indices is not "
"supported.")
def dist2serial_rev_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a distributed to serial reverse connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
# dist_out <-- serial_in
if is_full_slice:
abs2idx = self._var_allprocs_abs2idx
# serial input is using full slice here, so contains the full
# distributed value of the distributed output (and serial input will
# have the same value on all procs).
# dist input may not exist on all procs, so distribute the serial
# entries across only the procs where the dist output exists.
exist_procs = self._get_var_existence()['output'][:, abs2idx[to_node[1]]]
split_num = np.count_nonzero(exist_procs)
dist_shapes = self.get_dist_shapes()
dist_sizes = self.get_dist_sizes()
from_shape = self.nodes[from_node]['attrs'].shape
sz, _ = evenly_distrib_idxs(split_num, shape_to_len(from_shape))
sizes = np.zeros(self.comm.size, dtype=int)
sizes[exist_procs] = sz
dist_sizes[to_node] = sizes.copy()
to_meta = self.nodes[to_node]['attrs']
if len(from_shape) > 1:
if from_shape[0] != self.comm.size:
self._collect_error(f"Serial input '{from_node[1]}' has shape "
f"{from_shape} but output '{to_node[1]}' "
f"is distributed over {self.comm.size} "
f"procs and {from_shape[0]} != "
f"{self.comm.size}.")
return
else:
dist_shapes[to_node] = [from_shape[1:]] * self.comm.size
shp = to_meta.shape = from_shape[1:]
else:
dist_shapes[to_node] = [(s,) for s in sizes]
shp = to_meta.shape = (sizes[sizes != 0][0],)
to_meta.global_shape = from_shape
return shp
else:
self._collect_error(f"Input '{from_node[1]}' has src_indices so the shape "
f"of connected output '{to_node[1]}' cannot be "
"determined.")
def dist2dist_fwd_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a distributed to distributed forward connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
if is_full_slice:
self._collect_error(f"Using a full slice [:] as src_indices between"
f" distributed variables '{from_node[1]}' and "
f"'{to_node[1]}' is invalid.")
return
from_conn_meta = self.nodes[from_node]['attrs']
to_conn_meta = self.nodes[to_node]['attrs']
dist_shapes = self.get_dist_shapes()
dist_sizes = self.get_dist_sizes()
if not src_inds_list:
shp = from_conn_meta.shape
to_conn_meta.shape = shp
dist_shapes[to_node] = dist_shapes[from_node].copy()
dist_sizes[to_node] = dist_sizes[from_node].copy()
to_conn_meta.global_shape = from_conn_meta.global_shape
else:
shp = self.compute_global_shape(from_node)
for src_indices in src_inds_list:
src_indices.set_src_shape(shp)
shp = src_indices.indexed_src_shape
to_conn_meta.shape = shp
dist_shapes[to_node] = self.comm.allgather(shp)
dist_sizes[to_node] = np.array([shape_to_len(s) for s in dist_shapes[to_node]])
to_conn_meta.global_shape = self.compute_global_shape(to_node)
return shp
def dist2dist_rev_shape(self, from_node, to_node, src_inds_list, is_full_slice):
"""
Compute the shape for a distributed to distributed reverse connection.
Parameters
----------
from_node : tuple of the form ('i' or 'o', name)
The from node.
to_node : tuple of the form ('i' or 'o', name)
The to node.
src_inds_list : list of Indexers
The source indices list.
is_full_slice : bool
If True, the source is a full slice.
Returns
-------
tuple
The shape.
"""
if is_full_slice:
self._collect_error(f"Using a full slice [:] as src_indices between"
f" distributed variables '{from_node[1]}' and "
f"'{to_node[1]}' is invalid.")
elif not src_inds_list:
shp = self.nodes[from_node]['attrs'].shape
to_conn_meta = self.nodes[to_node]['attrs']
to_conn_meta.shape = shp
dist_shapes = self.get_dist_shapes()
dist_sizes = self.get_dist_sizes()
dist_shapes[to_node] = dist_shapes[from_node].copy()
dist_sizes[to_node] = dist_sizes[from_node].copy()
to_conn_meta.global_shape = self.compute_global_shape(to_node)
return shp
else:
self._collect_error(f"Input '{from_node[1]}' has src_indices so the shape "
f"of connected output '{to_node[1]}' cannot be "
"determined.")
def get_conn_tree_graph(self, node, inputs=True, outputs=True):
"""
Get the connection tree graph for a given node.
Parameters
----------
node : tuple of the form ('i' or 'o', name)
The node to get the connection tree for.
inputs : bool, optional
Whether to include input nodes in the graph.
outputs : bool, optional
Whether to include output nodes in the graph.
Returns
-------
nx.DiGraph
The connection tree graph containing all nodes connected to the given node.
"""
nodes = nx.node_connected_component(self.to_undirected(as_view=True), node)
if not inputs:
nodes = [n for n in nodes if n[0] == 'o']
if not outputs:
nodes = [n for n in nodes if n[0] == 'i']
return self.subgraph(nodes)
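# A brief usage sketch for get_conn_tree_graph (hypothetical variable name):
#
#     tree = graph.get_conn_tree_graph(('i', 'sub.comp.x'), outputs=False)
#     sorted(tree.nodes())   # only the 'i' nodes of that connection tree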