# Source code for openmdao.utils.code_utils

"""
Tools for working with code.
"""

import sys
import os
import inspect
import ast
import textwrap
import importlib
from collections import defaultdict, OrderedDict

import networkx as nx


def _get_long_name(node):
    """
    Return the full dotted name represented by an AST node.

    Parameters
    ----------
    node : ast.AST
        The node to examine.

    Returns
    -------
    str or None
        The dotted name if the node is a Name, or an Attribute chain made up only of
        Attribute nodes ending in a Name.  None if anything else (e.g., a Subscript
        or a Call) appears in the chain.
    """
    pieces = []
    current = node

    # Peel attributes off from the outside in, e.g. a.b.c yields 'c' then 'b'.
    while isinstance(current, ast.Attribute):
        pieces.append(current.attr)
        current = current.value

    # The innermost node must be a plain Name, otherwise this is more than just
    # a simple dotted name.
    if not isinstance(current, ast.Name):
        return None

    pieces.append(current.id)
    pieces.reverse()
    return '.'.join(pieces)


class _SelfCallCollector(ast.NodeVisitor):
    """
    An ast.NodeVisitor that records calls to self.* methods.

    Three call forms are recognized:

    - ``self.func(...)``
    - ``Class.func(self, ...)`` where ``Class`` is the name of a class in the MRO
    - ``super().func(...)`` and ``super(Class, self).func(...)``

    Parameters
    ----------
    class_ : class
        The class whose MRO is used to resolve the recorded calls.

    Attributes
    ----------
    self_calls : defaultdict of list
        Maps a class to the names of the methods called on it, in first-seen order.
    class_ : class
        The class passed to the constructor.
    mro : tuple
        The method resolution order of class_.
    mro_names : set
        The names of all classes in mro.
    """

    def __init__(self, class_):
        super().__init__()
        self.self_calls = defaultdict(list)
        self.class_ = class_
        self.mro = inspect.getmro(class_)
        self.mro_names = {c.__name__ for c in self.mro}

    def _record(self, klass, func_name, node):
        """
        Record func_name as called on klass and descend into the call's arguments.

        Parameters
        ----------
        klass : class
            The class the call is attributed to.
        func_name : str
            Name of the called method.
        node : ast.Call
            The call node.  Its func is not visited; its arguments are.
        """
        calls = self.self_calls[klass]
        if func_name not in calls:
            calls.append(func_name)

        # Arguments may themselves contain self.* calls, both positional and keyword.
        for arg in node.args:
            self.visit(arg)
        for kw in node.keywords:
            self.visit(kw.value)

    def _super_start_name(self, callnode):
        """
        Return the name of the class that a super(...) call starts its lookup from.

        Parameters
        ----------
        callnode : ast.Call
            The ``super(...)`` call node.

        Returns
        -------
        str or None
            The class name, or None if the call isn't of the form ``super()`` or
            ``super(SimpleName, self)``.
        """
        args = callnode.args
        if not args:
            # zero-arg super() is resolved relative to the class given to the constructor
            return self.mro[0].__name__

        if len(args) != 2:
            # e.g. super(Class); there is no instance argument to inspect
            return None

        sup_0 = _get_long_name(args[0])
        sup_1 = _get_long_name(args[1])
        if sup_1 == 'self' and sup_0 is not None and '.' not in sup_0:
            return sup_0
        return None

    def visit_Call(self, node):
        """
        Record the call if it targets a method of self, else keep descending.

        Parameters
        ----------
        node : ast.Call
            The call node being visited.
        """
        fncname = _get_long_name(node.func)

        if fncname is not None:
            parts = fncname.split('.')

            # self.func(...) form
            if len(parts) == 2 and parts[0] == 'self':
                self._record(self.class_, parts[1], node)
                return

            # check for Class.func(inst) form for base class method call
            if (len(parts) == 2 and parts[0] in self.mro_names and node.args and
                    isinstance(node.args[0], ast.Name) and node.args[0].id == 'self'):
                cname, func = parts
                for c in self.mro:
                    if c.__name__ == cname:
                        # attribute the call to the class that actually defines func
                        for sub_c in inspect.getmro(c):
                            if func in sub_c.__dict__:
                                c = sub_c
                                break
                        self._record(c, func, node)
                        return

            self.generic_visit(node)
            return

        # check for super() call.  This only works for a single call level.
        if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Call):
            callnode = node.func.value
            if _get_long_name(callnode.func) == 'super':
                sup_0 = self._super_start_name(callnode)
                if sup_0 is not None:
                    for c in self.mro[:-1]:
                        if c.__name__ == sup_0:
                            # we need super of the specified class, i.e. the first
                            # entry in its MRO that isn't the class itself
                            base = next((b for b in inspect.getmro(c) if b is not c), c)
                            self._record(base, node.func.attr, node)
                            return

        # Not a call form we understand, so keep looking for calls nested inside it.
        self.generic_visit(node)


def _find_owning_class(mro, func_name):
    """
    Return the full funcname and class where the function is first found in the class MRO.

    Parameters
    ----------
    mro : iterable of class
        Classes to search, in method resolution order.
    func_name : str
        Name of the function to look for in each class __dict__.

    Returns
    -------
    str or None
        'ClassName.func_name' for the first class defining func_name, or None if no
        class defines it.
    class or None
        The defining class, or None if no class defines it.
    """
    # TODO: this won't work for classes with __slots__

    owner = next((klass for klass in mro if func_name in klass.__dict__), None)
    if owner is None:
        return None, None

    return '.'.join((owner.__name__, func_name)), owner


def _get_nested_calls(starting_class, class_, func_name, parent, graph, seen):
    """
    Parse the AST of the given method and all 'self' methods it calls and record owning classes.

    Parameters
    ----------
    starting_class : class
        The class the traversal started from.  Its MRO is used to resolve calls.
    class_ : class
        The class owning the method to be parsed.
    func_name : str
        Name of the method to be parsed.
    parent : str
        Graph node name ('ClassName.func_name') of the method being parsed.
    graph : networkx.DiGraph
        Graph that receives an edge from parent to each method it calls.
    seen : set
        Names of methods that have already been parsed.  Updated in place.
    """
    seen.add('.'.join((class_.__name__, func_name)))

    func = getattr(class_, func_name)
    try:
        src = inspect.getsource(func)
    except (TypeError, OSError):
        # No python source is available, e.g., for builtin methods like
        # object.__init__ or for properties.  Leave it in the graph as a leaf
        # instead of aborting the whole traversal.
        return

    node = ast.parse(textwrap.dedent(src), mode='exec')
    visitor = _SelfCallCollector(starting_class)
    visitor.visit(node)

    # now find the actual owning class for each call
    for klass, funcset in visitor.self_calls.items():
        mro = inspect.getmro(klass)
        for f in funcset:
            full, c = _find_owning_class(mro, f)
            if full is not None:
                graph.add_edge(parent, full)
                if full not in seen:
                    _get_nested_calls(starting_class, c, f, full, graph, seen)


def get_nested_calls(class_, method_name, stream=sys.stdout):
    """
    Display the call tree for the specified class method and all 'self' class methods it calls.

    Parameters
    ----------
    class_ : class
        The starting class.
    method_name : str
        The name of the class method.
    stream : file-like
        The output stream where output will be displayed.

    Returns
    -------
    networkx.DiGraph
        A graph containing edges from methods to their sub-methods.
    """
    # moved this class def in here to keep the numpy doc scraper from barfing due to
    # stuff in nx.DiGraph.
    class OrderedDiGraph(nx.DiGraph):
        """
        A DiGraph using OrderedDicts for internal storage.
        """

        node_dict_factory = OrderedDict
        adjlist_dict_factory = OrderedDict
        edge_attr_dict_factory = OrderedDict

    graph = OrderedDiGraph()
    seen = set()

    # unique root node that can't collide with any 'Class.method' name
    top = object()

    full, klass = _find_owning_class(inspect.getmro(class_), method_name)
    if full is None:
        print("Can't find function '%s' in class '%s'." % (method_name, class_.__name__))
    else:
        graph.add_edge(top, full)
        _get_nested_calls(class_, klass, method_name, full, graph, seen)

    if graph and stream is not None:
        # Iterative depth first traversal.  Each stack entry holds the depth and an
        # iterator over the children still to be displayed at that depth.
        seen = {top}
        stack = [(0, iter(graph[top]))]
        while stack:
            depth, children = stack[-1]
            try:
                n = next(children)
                stream.write("%s%s\n" % (' ' * depth, n))
                # only expand a node the first time it's displayed
                if n not in seen:
                    stack.append((depth + 1, iter(graph[n])))
                    seen.add(n)
            except StopIteration:
                stack.pop()

    return graph
def _calltree_setup_parser(parser):
    """
    Set up the command line options for the 'openmdao call_tree' command line tool.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        The parser the arguments are added to (presumably an argparse parser; only
        add_argument is called on it).
    """
    parser.add_argument('method_path', nargs=1,
                        help='Full module path to desired class method, e.g., '
                        '"openmdao.components.exec_comp.ExecComp.setup".')
    parser.add_argument('-o', '--outfile', action='store', dest='outfile',
                        default='stdout', help='Output file. Defaults to stdout.')


def _calltree_exec(options, user_args):
    """
    Process command line args and call get_nested_calls on the specified class method.

    Parameters
    ----------
    options : argparse.Namespace
        Command line options.  The method_path and outfile attributes are read.
    user_args : list of str
        Unused here.
    """
    parts = options.method_path[0].split('.')
    if len(parts) < 3:
        raise RuntimeError("You must supply the full module path to the function, "
                           "for example: openmdao.api.Group._setup.")

    class_name = parts[-2]
    func_name = parts[-1]
    modpath = '.'.join(parts[:-2])

    # allow modules in the current directory to be imported
    sys.path.append(os.getcwd())

    mod = importlib.import_module(modpath)
    klass = getattr(mod, class_name)

    stream_map = {'stdout': sys.stdout, 'stderr': sys.stderr}
    stream = stream_map.get(options.outfile)
    if stream is None:
        # We own this file, so make sure it gets flushed and closed.
        with open(options.outfile, 'w') as stream:
            get_nested_calls(klass, func_name, stream)
    else:
        get_nested_calls(klass, func_name, stream)


def _target_iter(targets):
    """
    Iterate over assignment targets, expanding one level of tuple targets.

    Parameters
    ----------
    targets : iterable of ast.AST
        Target nodes of an assignment.

    Yields
    ------
    ast.AST
        Each non-tuple target, or each element of a tuple target.
    """
    for target in targets:
        if isinstance(target, ast.Tuple):
            for t in target.elts:
                yield t
        else:
            yield target


class _AttrCollector(ast.NodeVisitor):
    """
    An ast.NodeVisitor that records class attribute names.

    Attribute names are collected from assignments of the form ``self.name = ...``
    and from methods decorated with ``@property``.

    Parameters
    ----------
    class_dict : dict
        Dict mapping dotted class names to sets of attribute names.  Updated in place.

    Attributes
    ----------
    class_dict : dict
        The dict passed to the constructor.
    class_stack : list of str
        Dotted names of the classes currently being visited, outermost first.
    func_stack : list of str
        Names of the functions currently being visited, outermost first.
    names : list of str or None
        Name parts of the assignment target being visited, or None when not
        visiting an assignment target.
    decnames : list of str or None
        Names found in the decorator being visited, or None when not visiting
        a decorator.
    """

    def __init__(self, class_dict):
        super().__init__()
        self.class_dict = class_dict
        self.class_stack = []
        self.func_stack = []
        self.names = None
        self.decnames = None

    def get_attributes(self):
        """
        Return the dict mapping class names to attribute names.

        Returns
        -------
        dict
            The dict passed to the constructor, including any collected attributes.
        """
        return self.class_dict

    def visit_ClassDef(self, node):
        """
        Collect the attributes of a class and of any classes nested inside it.

        Parameters
        ----------
        node : ast.ClassDef
            The class definition node.
        """
        # The stack holds full dotted names, so only the innermost entry is needed
        # as a prefix.  Joining the whole stack would repeat the outer names for
        # classes nested more than one level deep.
        full_name = '.'.join(self.class_stack[-1:] + [node.name])
        self.class_stack.append(full_name)
        self.class_dict[full_name] = set()

        for stmt in node.body:
            self.visit(stmt)

        self.class_stack.pop()
        if self.func_stack:
            # ignore classes nested in functs
            del self.class_dict[full_name]

    def visit_FunctionDef(self, node):
        """
        Visit a function body and record the function as an attribute if it's a property.

        Parameters
        ----------
        node : ast.FunctionDef
            The function definition node.
        """
        self.func_stack.append(node.name)
        for stmt in node.body:
            self.visit(stmt)
        self.func_stack.pop()

        if self.class_stack:
            # see if this is a property, and if so, treat as an attribute
            for dec in node.decorator_list:
                self.decnames = []
                self.visit(dec)
                if len(self.decnames) == 1 and self.decnames[0] == 'property':
                    self.class_dict[self.class_stack[-1]].add(node.name)
                self.decnames = None

    def visit_Assign(self, node):
        """
        Record the attribute name for any target of the form self.name.

        Parameters
        ----------
        node : ast.Assign
            The assignment node.
        """
        if self.class_stack:
            for t in _target_iter(node.targets):
                self.names = []
                self.visit(t)
                if len(self.names) > 1 and self.names[0] == 'self':
                    self.class_dict[self.class_stack[-1]].add(self.names[1])
                self.names = None

    def visit_Attribute(self, node):
        """
        Append the parts of a dotted name while an assignment target is being visited.

        Parameters
        ----------
        node : ast.Attribute
            The attribute node.
        """
        if self.names is not None:
            # visit the value first so the parts end up in left to right order
            self.visit(node.value)
            self.names.append(node.attr)

    def visit_Name(self, node):
        """
        Record a name for the assignment target or decorator being visited.

        Parameters
        ----------
        node : ast.Name
            The name node.
        """
        if self.names is not None:
            self.names.append(node.id)
        elif self.decnames is not None:
            self.decnames.append(node.id)
def get_class_attributes(fname, class_dict=None):
    """
    Find all referenced attributes in all classes defined in the given file.

    Parameters
    ----------
    fname : str
        File name.
    class_dict : dict or None
        Dict mapping class names to attribute names.

    Returns
    -------
    dict
        The dict maps class name to a set of attribute names.
    """
    if class_dict is None:
        class_dict = {}

    with open(fname, 'r') as f:
        source = f.read()

    # passing the file name gives syntax errors a useful location
    node = ast.parse(source, filename=fname, mode='exec')
    visitor = _AttrCollector(class_dict)
    visitor.visit(node)

    return visitor.get_attributes()
if __name__ == '__main__':
    # When run as a script, display the attributes of the classes in this file.
    import pprint
    pprint.pprint(get_class_attributes(__file__))