"""
Tools for working with code.
"""
import sys
import os
import io
import inspect
import ast
import textwrap
import importlib
from types import LambdaType
from collections import defaultdict, OrderedDict
from tokenize import NAME, tokenize, untokenize
import networkx as nx
def _get_long_name(node):
# If the node is an Attribute or Name node that is composed
# only of other Attribute or Name nodes, then return the full
# dotted name for this node. Otherwise, i.e., if this node
# contains Subscripts or Calls, return None.
if isinstance(node, ast.Name):
return node.id
elif not isinstance(node, ast.Attribute):
return None
val = node.value
parts = [node.attr]
while True:
if isinstance(val, ast.Attribute):
parts.append(val.attr)
val = val.value
elif isinstance(val, ast.Name):
parts.append(val.id)
break
else: # it's more than just a simple dotted name
return None
return '.'.join(parts[::-1])
class _SelfCallCollector(ast.NodeVisitor):
"""
An ast.NodeVisitor that records calls to self.* methods.
"""
def __init__(self, class_):
super().__init__()
self.self_calls = defaultdict(list)
self.class_ = class_
self.mro = inspect.getmro(class_)
self.mro_names = set([c.__name__ for c in self.mro])
def visit_Call(self, node): # (func, args, keywords, starargs, kwargs)
fncname = _get_long_name(node.func)
class_ = self.class_
if fncname is not None:
if fncname.startswith('self.') and len(fncname.split('.')) == 2:
shortfnc = fncname.split('.')[1]
if shortfnc not in self.self_calls[class_]:
self.self_calls[class_].append(shortfnc)
for arg in node.args:
self.visit(arg)
# check for Class.func(inst) form for base class method call
elif (len(fncname.split('.')) == 2 and fncname.split('.')[0] in self.mro_names and
node.args and isinstance(node.args[0], ast.Name) and node.args[0].id == 'self'):
cname, func = fncname.split('.')
for c in self.mro:
if c.__name__ == cname:
sub_mro = inspect.getmro(c)
for sub_c in sub_mro:
if func in sub_c.__dict__:
c = sub_c
break
if func not in self.self_calls[c]:
self.self_calls[c].append(func)
for arg in node.args:
self.visit(arg)
break
else:
self.generic_visit(node)
else:
self.generic_visit(node)
# check for super() call
elif isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Call):
callnode = node.func.value
n = _get_long_name(callnode.func)
# if this is a 'super' call, get the base of the specified class
if n == 'super': # this only works for a single call level
if len(callnode.args) == 0:
sup_0 = self.mro[0].__name__
visit_super = True
else:
sup_1 = _get_long_name(callnode.args[1])
sup_0 = _get_long_name(callnode.args[0])
visit_super = (sup_1 == 'self' and
sup_0 is not None and len(sup_0.split('.')) == 1)
if visit_super:
for i, c in enumerate(self.mro[:-1]):
if sup_0 == c.__name__:
# we need super of the specified class
sub_mro = inspect.getmro(c)
for sub_c in sub_mro:
if sub_c is not c:
c = sub_c
break
fn = node.func.attr
if fn not in self.self_calls[c]:
self.self_calls[c].append(fn)
for arg in node.args:
self.visit(arg)
break
else:
self.generic_visit(node)
else:
self.generic_visit(node)
else:
self.generic_visit(node)
def _find_owning_class(mro, func_name):
"""
Return the full funcname and class where the function is first found in the class MRO.
"""
# TODO: this won't work for classes with __slots__
for c in mro:
if func_name in c.__dict__:
return '.'.join((c.__name__, func_name)), c
return None, None
def _get_nested_calls(starting_class, class_, func_name, parent, graph, seen):
"""
Parse the AST of the given method and all 'self' methods it calls and record owning classes.
"""
func = getattr(class_, func_name)
src = inspect.getsource(func)
dedented_src = textwrap.dedent(src)
node = ast.parse(dedented_src, mode='exec')
visitor = _SelfCallCollector(starting_class)
visitor.visit(node)
seen.add('.'.join((class_.__name__, func_name)))
# now find the actual owning class for each call
for klass, funcset in visitor.self_calls.items():
mro = inspect.getmro(klass)
for f in funcset:
full, c = _find_owning_class(mro, f)
if full is not None:
graph.add_edge(parent, full)
if full not in seen:
_get_nested_calls(starting_class, c, f, full, graph, seen)
[docs]
def get_nested_calls(class_, method_name, stream=sys.stdout):
"""
Display the call tree for the specified class method and all 'self' class methods it calls.
Parameters
----------
class_ : class
The starting class.
method_name : str
The name of the class method.
stream : file-like
The output stream where output will be displayed.
Returns
-------
networkx.DiGraph
A graph containing edges from methods to their sub-methods.
"""
# moved this class def in here to keep the numpy doc scraper from barfing due to
# stuff in nx.DiGraph.
class OrderedDiGraph(nx.DiGraph):
"""
A DiGraph using OrderedDicts for internal storage.
"""
node_dict_factory = OrderedDict
adjlist_dict_factory = OrderedDict
edge_attr_dict_factory = OrderedDict
graph = OrderedDiGraph()
seen = set()
top = object()
full, klass = _find_owning_class(inspect.getmro(class_), method_name)
if full is None:
print("Can't find function '%s' in class '%s'." % (method_name, class_.__name__))
else:
graph.add_edge(top, full)
parent = full
_get_nested_calls(class_, klass, method_name, parent, graph, seen)
if graph and stream is not None:
seen = set([top])
stack = [(0, iter(graph[top]))]
while stack:
depth, children = stack[-1]
try:
n = next(children)
stream.write("%s%s\n" % (' ' * depth, n))
if n not in seen:
stack.append((depth + 1, iter(graph[n])))
seen.add(n)
except StopIteration:
stack.pop()
return graph
def _calltree_setup_parser(parser):
"""
Set up the command line options for the 'openmdao call_tree' command line tool.
"""
parser.add_argument('method_path', nargs=1,
help='Full module path to desired class method, e.g., '
'"openmdao.components.exec_comp.ExecComp.setup".')
parser.add_argument('-o', '--outfile', action='store', dest='outfile',
default='stdout', help='Output file. Defaults to stdout.')
def _calltree_exec(options, user_args):
"""
Process command line args and call get_nested_calls on the specified class method.
"""
parts = options.method_path[0].split('.')
if len(parts) < 3:
raise RuntimeError("You must supply the full module path to the function, "
"for example: openmdao.api.Group._setup.")
class_name = parts[-2]
func_name = parts[-1]
modpath = '.'.join(parts[:-2])
old_syspath = sys.path[:]
sys.path.append(os.getcwd())
try:
mod = importlib.import_module(modpath)
klass = getattr(mod, class_name)
stream_map = {'stdout': sys.stdout, 'stderr': sys.stderr}
stream = stream_map.get(options.outfile)
if stream is None:
stream = open(options.outfile, 'w')
get_nested_calls(klass, func_name, stream)
finally:
sys.path = old_syspath
def _target_iter(targets):
for target in targets:
if isinstance(target, ast.Tuple):
for t in target.elts:
yield t
else:
yield target
class _AttrCollector(ast.NodeVisitor):
"""
An ast.NodeVisitor that records class attribute names.
"""
def __init__(self, class_dict):
super().__init__()
self.class_dict = class_dict
self.class_stack = []
self.func_stack = []
self.names = None
self.decnames = None
def get_attributes(self):
return self.class_dict
def visit_ClassDef(self, node):
full_name = '.'.join(self.class_stack[:] + [node.name])
self.class_stack.append(full_name)
self.class_dict[full_name] = set()
for stmt in node.body:
self.visit(stmt)
self.class_stack.pop()
if self.func_stack: # ignore classes nested in functs
del self.class_dict[full_name]
def visit_FunctionDef(self, node):
self.func_stack.append(node.name)
for stmt in node.body:
self.visit(stmt)
self.func_stack.pop()
if self.class_stack:
# see if this is a property, and if so, treat as an attribute
for dec in node.decorator_list:
self.decnames = []
self.visit(dec)
if len(self.decnames) == 1 and self.decnames[0] == 'property':
self.class_dict[self.class_stack[-1]].add(node.name)
self.decnames = None
def visit_Assign(self, node):
if self.class_stack:
for t in _target_iter(node.targets):
self.names = []
self.visit(t)
if len(self.names) > 1 and self.names[0] == 'self':
self.class_dict[self.class_stack[-1]].add(self.names[1])
self.names = None
def visit_Attribute(self, node):
if self.names is not None:
self.visit(node.value)
self.names.append(node.attr)
def visit_Name(self, node):
if self.names is not None:
self.names.append(node.id)
elif self.decnames is not None:
self.decnames.append(node.id)
[docs]
def get_class_attributes(fname, class_dict=None):
"""
Find all referenced attributes in all classes defined in the given file.
Parameters
----------
fname : str
File name.
class_dict : dict or None
Dict mapping class names to attribute names.
Returns
-------
dict
The dict maps class name to a set of attribute names.
"""
if class_dict is None:
class_dict = {}
with open(fname, 'r') as f:
source = f.read()
node = ast.parse(source, mode='exec')
visitor = _AttrCollector(class_dict)
visitor.visit(node)
return visitor.get_attributes()
def _get_return_name(node):
return node.id if isinstance(node, ast.Name) else None
def _get_return_names(outs):
"""
Return a list of (name or None) for each return value.
If there are multiple returns that differ by name or number of return values, an exception
will be raised. If one entry in one return list has a name and another is None, the name
will take precedence and no exception will be raised.
Returns
-------
list
The list of return names. Some entries will be None if there was no simple name
associated with a given return value.
"""
if len(outs) == 0:
return []
if len(outs) == 1:
return outs[0]
names = outs[0].copy()
length = len(names)
for lst in outs[1:]:
if len(lst) != length:
raise RuntimeError("Function has multiple return statements with differing numbers "
"of return values.")
for i, (name, newname) in enumerate(zip(names, lst)):
if name is None:
names[i] = newname
elif newname is not None and name != newname:
raise RuntimeError("Function has multiple return statements with different "
f"return value names of {sorted((name, newname))} for "
f"return value {i}.")
return names
[docs]
def get_return_names(func):
"""
Return the names of the variables returned by the given function.
Returns None for any return values that aren't a simple name.
Parameters
----------
func : function
The function to be examined.
Returns
-------
list
The names of the variables returned by the given function.
"""
class _FuncRetNameCollector(ast.NodeVisitor):
"""
An ast.NodeVisitor that records return value names.
Attributes
----------
_ret_infos : list
List containing one entry for each return statement, with each entry containing a list
of name (or None) for each function return value.
"""
def __init__(self, func):
super().__init__()
self._ret_infos = []
self.visit(ast.parse(textwrap.dedent(inspect.getsource(func)), mode='exec'))
def get_return_names(self):
"""
Return a list of (name or None) for each return value.
If there are multiple returns that differ by name or number of return values, an
exception will be raised. If one entry in one return list has a name and another is
None, the name will take precedence and no exception will be raised.
Returns
-------
list
The list of return names. Some entries will be None if there was no simple name
associated with a given return value.
"""
return _get_return_names(self._ret_infos)
def visit_Return(self, node):
"""
Visit a Return node.
Parameters
----------
node : ASTnode
The return node being visited.
"""
self._ret_infos.append([])
if isinstance(node.value, ast.Tuple):
for n in node.value.elts:
self._ret_infos[-1].append(_get_return_name(n))
else:
self._ret_infos[-1].append(_get_return_name(node.value))
return _FuncRetNameCollector(func).get_return_names()
class _FuncGrapher(ast.NodeVisitor):
"""
An ast.NodeVisitor that builds a graph between a function's inputs and outputs.
"""
def __init__(self, node):
super().__init__()
self.rhs = []
self.lhs = []
self.names = None
self.graph = nx.DiGraph()
self.outs = []
self.fstack = []
self.visit(node)
def _update_graph(self):
for inp in self.rhs:
for out in self.lhs:
self.graph.add_edge(inp, out)
self.lhs = []
self.rhs = []
def visit_FunctionDef(self, node):
if self.fstack:
raise RuntimeError("Function contains nested functions, which are not supported.")
self.fstack.append(node)
for stmt in node.body:
self.visit(stmt)
self.fstack.pop()
def visit_Assign(self, node):
self.names = self.lhs
for t in _target_iter(node.targets):
self.visit(t)
self.names = self.rhs
self.visit(node.value)
self.names = None
self._update_graph()
def visit_Attribute(self, node):
pass # skip any Name nodes that are part of an Attribute node
def visit_Call(self, node):
for arg in node.args:
self.visit(arg)
def visit_Name(self, node):
if self.names is not None:
self.names.append(node.id)
def visit_Return(self, node):
self.outs.append([])
if isinstance(node.value, ast.Tuple):
it = enumerate(node.value.elts)
else:
it = [(0, node.value)]
for i, n in it:
self.lhs = [f"@out{i}"]
self.rhs = []
self.names = self.rhs
self.visit(n)
self._update_graph()
self.outs[-1].append(_get_return_name(n))
[docs]
def get_func_graph(func, outnames=None, display=False):
"""
Generate a graph between a function's inputs and outputs.
Uses the AST to analyze the function and build a graph between inputs and outputs, so the
function source must be available.
Parameters
----------
func : Callable
The function to be analyzed.
outnames : list or None
The list of expected output variable names.
display : bool
If True, display the graph using pydot.
Returns
-------
networkx.DiGraph
A graph containing edges from inputs to outputs.
"""
node = ast.parse(textwrap.dedent(inspect.getsource(func)), mode='exec')
visitor = _FuncGrapher(node)
retnames = _get_return_names(visitor.outs)
inputs = set(inspect.signature(func).parameters)
# check vs outnames
if outnames is not None:
if len(retnames) != len(outnames):
raise RuntimeError("Number of return values in function does not match number of "
f"expected return names. ({outnames}) != ({retnames})")
for ret, name in zip(retnames, outnames):
if ret is not None and ret != name:
raise RuntimeError(f"Return value name '{name}' in function does not match "
f"expected name '{ret}.")
else:
outnames = []
for i, ret in enumerate(retnames):
if ret is None or ret in inputs:
outnames.append(f'out{len(outnames)}')
else:
outnames.append(ret)
mapping = {f'@out{i}': name for i, name in enumerate(outnames)}
visitor.graph = nx.relabel_nodes(visitor.graph, mapping)
to_remove = [e for e in visitor.graph.edges() if e[0] == e[1]]
visitor.graph.remove_edges_from(to_remove)
# make sure all outputs exist as nodes in graph, even if they have no incoming edges
visitor.graph.add_nodes_from(outnames)
visitor.graph.graph['inputs'] = inputs
visitor.graph.graph['outputs'] = outnames
if display:
# show the function graph visually
from openmdao.visualization.graph_viewer import write_graph, _to_pydot_graph
write_graph(_to_pydot_graph(visitor.graph))
return visitor.graph
[docs]
def get_partials_deps(func, outputs=None):
"""
Generate tuples of the form (output, input) for the given function.
Only tuples where the output depends on the input are yielded. This can be used to
determine which partials need to be declared.
Note that currently the function grapher doesn't recurse into functions and assumes that all
outputs of a sub-function are dependent on all inputs to that function. This may lead to
a conservative estimate of partials that need to be declared.
Parameters
----------
func : Callable
The function to be analyzed.
outputs : list or None
The list of output variable names.
Yields
------
tuple
A tuple of the form (output, input).
"""
graph = get_func_graph(func, outputs)
outs = graph.graph['outputs']
successors = graph.successors
for start in graph.graph['inputs']:
visited = set([start])
stack = [(start, successors(start))]
while stack:
_, succs = stack[-1]
for succ in succs:
if succ not in visited:
visited.add(succ)
if succ in outs:
yield succ, start
stack.append((succ, successors(succ)))
break
else:
stack.pop()
[docs]
def block_filter(tokiter, blocks_to_remove, block_start_tok):
"""
Remove blocks of code from a stream of tokens.
Blocks are removed based on indentation level. If a block's name matches one in
blocks_to_remove, all non-blank lines where the first token is indented to a greater level
than the block start token are removed.
Parameters
----------
tokiter : iterator
Iterator of tokens.
blocks_to_remove : set
Set of block names to remove.
block_start_tok : str
The name of the block start token, e.g., 'def' or 'class'.
Yields
------
tuple
The next token in the stream, unless it is part of a block that
should be removed.
"""
indent = None
save = []
for tok in tokiter:
toktype, tokval, start, _, _ = tok
tokcol = start[1]
if save: # we're on block start line after block start token
if toktype == NAME and tokval not in blocks_to_remove:
indent = None
yield from save
yield tok
save = []
continue
elif toktype == NAME and tokval == block_start_tok: # block start line
indent = tokcol
save.append(tok) # we might need to emit this token if block doesn't match
continue
elif indent is not None:
if tokcol > indent or not tokval.strip(): # skip lines that are indented or blank
continue
else: # block is done
indent = None
yield tok
[docs]
def find_block_start(srccode, block_name, block_start_tok):
"""
Find the start of a block of code.
Parameters
----------
srccode : str
Source code to search for block.
block_name : str
The name of the block to find.
block_start_tok : str
The name of the block start token, e.g., 'def' or 'class'.
Returns
-------
tuple
A tuple of the form (line number, column number, block start line) or (None, None, None) if
the block was not found.
"""
tokiter = tokenize(io.BytesIO(srccode.encode('utf-8')).readline)
for tok in tokiter:
toktype, tokval, start, _, _ = tok
if toktype == NAME and tokval == block_start_tok: # block start line
try:
nxt = next(tokiter)
except StopIteration:
return (None, None, None)
ntoktype, ntokval, _, _, _ = nxt
if ntoktype == NAME and ntokval == block_name:
return start[0], start[1], tok[-1]
return (None, None, None)
[docs]
def remove_src_blocks(srccode, names, block_start_tok):
"""
Remove blocks from a piece of source code.
Parameters
----------
srccode : str
The source code.
names : list of str
List of blocks to be removed.
block_start_tok : str
The name of the block start token, e.g., 'def' or 'class'.
Returns
-------
str
The modified source code.
"""
if not names:
return srccode
return untokenize(block_filter(tokenize(io.BytesIO(srccode.encode('utf-8')).readline),
set(names), block_start_tok=block_start_tok)).decode()
[docs]
def replace_src_block(srccode, block_name, new_block, block_start_tok):
"""
Replace a block in a piece of source code.
Parameters
----------
srccode : str
The source code.
block_name : str
The name of the block to be replaced.
new_block : str
The replacement block.
block_start_tok : str
The name of the block start token, e.g., 'def' or 'class'.
Returns
-------
str
The modified source code.
"""
linenum, _, _ = find_block_start(srccode, block_name, block_start_tok)
if linenum is None:
raise RuntimeError(f"Block '{block_start_tok} {block_name}' not found in source code.")
stream = io.StringIO(srccode)
lines = srccode.splitlines()
linenum -= 1 # i is zero indexed, so adjust linenum to be zero indexed
for i, line in enumerate(lines):
if i == linenum:
# insert new block
for subline in new_block.splitlines():
print(subline, file=stream)
print('', file=stream)
print(line, file=stream)
newsrc = stream.getvalue()
# now remove the old block
return untokenize(block_filter(tokenize(io.BytesIO(newsrc.encode('utf-8')).readline),
{block_name}, block_start_tok=block_start_tok)).decode()
[docs]
def is_lambda(f):
"""
Return True if the given function is a lambda function.
Parameters
----------
f : function
The function to check.
Returns
-------
bool
True if the given function is a lambda function.
"""
return isinstance(f, LambdaType) and f.__name__ == "<lambda>"
[docs]
class LambdaPickleWrapper(object):
"""
A wrapper for a lambda function that allows it to be pickled.
Parameters
----------
lambda_func : function
The lambda function to be wrapped.
Attributes
----------
_func : function
The lambda function.
_src : str
The isolated source of the lambda function.
"""
[docs]
def __init__(self, lambda_func):
"""
Initialize the wrapper.
Parameters
----------
lambda_func : function
The lambda function to be wrapped.
"""
self._func = lambda_func
self._src = None
def __call__(self, *args, **kwargs):
"""
Call the lambda function.
Parameters
----------
*args : list
Positional arguments.
**kwargs : dict
Keyword arguments.
Returns
-------
object
The result of the lambda function.
"""
return self._func(*args, **kwargs)
def __getstate__(self):
"""
Return the state of this object for pickling.
The lambda function is converted to a string for pickling.
Returns
-------
dict
The state of this object.
"""
state = self.__dict__.copy()
state['_func'] = self._getsrc()
return state
def __setstate__(self, state):
"""
Restore the state of this object after pickling.
Parameters
----------
state : dict
The state of this object.
"""
self.__dict__.update(state)
self._func = eval(state['_func']) # nosec
def _getsrc(self):
if self._src is None:
self._src = _LambdaSrcFinder(self._func).src
if self._src is None:
raise RuntimeError("The fix for pickling lambda functions only works for python "
"3.9 and above. Try updating to a newer python version or "
"replacing the lambda with a regular function.")
return self._src
class _LambdaSrcFinder(ast.NodeVisitor):
"""
Given a lambda function, isolate the lambda function source from any surrounding code.
"""
def __init__(self, func):
super().__init__()
self.src = None
# note that inspect.getsource gives the source for the line that contains the lambda
# function, so we have to isolate the lambda function itself
self.visit(ast.parse(textwrap.dedent(inspect.getsource(func)), filename='<string>'))
def visit_Lambda(self, node):
if self.src is not None:
# it's possible to have multiple lambdas defined on the same line, so raise an error
# if we find more than one.
raise RuntimeError("Only one lambda function is allowed per line when using "
"_LambdaSrcFinder.")
try:
self.src = ast.unparse(node)
except AttributeError:
# ast.unparse was added in python 3.9
self.src = None
if __name__ == '__main__':
import pprint
pprint.pprint(get_class_attributes(__file__))