# Source code for openmdao.utils.code_utils

"""
Tools for working with code.
"""

import sys
import os
import inspect
import ast
import textwrap
import importlib
from types import LambdaType
from collections import defaultdict, OrderedDict

import networkx as nx


def _get_long_name(node):
    # If the node is an Attribute or Name node that is composed
    # only of other Attribute or Name nodes, then return the full
    # dotted name for this node. Otherwise, i.e., if this node
    # contains Subscripts or Calls, return None.
    if isinstance(node, ast.Name):
        return node.id
    elif not isinstance(node, ast.Attribute):
        return None
    val = node.value
    parts = [node.attr]
    while True:
        if isinstance(val, ast.Attribute):
            parts.append(val.attr)
            val = val.value
        elif isinstance(val, ast.Name):
            parts.append(val.id)
            break
        else:  # it's more than just a simple dotted name
            return None
    return '.'.join(parts[::-1])


class _SelfCallCollector(ast.NodeVisitor):
    """
    An ast.NodeVisitor that records calls to self.* methods.

    Three call forms are recognized:

    * ``self.meth(...)`` -- recorded under the starting class.
    * ``Base.meth(self, ...)`` where ``Base`` is the name of a class in the starting class's
      MRO -- recorded under the first class in ``Base``'s own MRO whose ``__dict__`` defines
      ``meth`` (or under ``Base`` itself if none does).
    * ``super().meth(...)`` / ``super(Cls, self).meth(...)`` -- recorded under the class that
      follows ``Cls`` in ``Cls``'s own MRO (``Cls`` defaults to the starting class).

    Any other call is searched recursively for nested calls of these forms.

    Parameters
    ----------
    class_ : class
        The starting class.

    Attributes
    ----------
    self_calls : defaultdict(list)
        Maps a class to the ordered, de-duplicated list of method names called on it.
    class_ : class
        The starting class.
    mro : tuple
        MRO of the starting class.
    mro_names : set of str
        Names of all classes in ``mro``.
    """

    def __init__(self, class_):
        super().__init__()
        self.self_calls = defaultdict(list)
        self.class_ = class_
        self.mro = inspect.getmro(class_)
        self.mro_names = set([c.__name__ for c in self.mro])

    def _record(self, klass, func_name):
        # Keep self_calls[klass] in first-seen order and free of duplicates.
        calls = self.self_calls[klass]
        if func_name not in calls:
            calls.append(func_name)

    def _visit_call_args(self, node):
        # The callee (node.func) has already been handled, so only the positional and keyword
        # argument expressions still need to be searched for nested calls.
        for arg in node.args:
            self.visit(arg)
        for kw in node.keywords:
            self.visit(kw.value)

    def _super_base(self, callnode):
        """
        Return the class a ``super(...)`` call node resolves to, or None if not resolvable.
        """
        if _get_long_name(callnode.func) != 'super':
            return None

        # this only works for a single call level
        if not callnode.args:
            cls_name = self.mro[0].__name__
        elif len(callnode.args) == 2 and _get_long_name(callnode.args[1]) == 'self':
            cls_name = _get_long_name(callnode.args[0])
            if cls_name is None or '.' in cls_name:
                return None
        else:
            # e.g. super(Cls) or super(Cls, other): not a call on this instance. (The
            # one-argument form used to raise IndexError here.)
            return None

        for c in self.mro[:-1]:
            if c.__name__ == cls_name:
                # we need super of the specified class, i.e. the next class in its own MRO
                sub_mro = inspect.getmro(c)
                return sub_mro[1] if len(sub_mro) > 1 else c
        return None

    def visit_Call(self, node):
        """
        Record recognized self/base-class/super method calls and search everything else.
        """
        fncname = _get_long_name(node.func)
        if fncname is not None:
            parts = fncname.split('.')
            if len(parts) == 2 and parts[0] == 'self':
                self._record(self.class_, parts[1])
                self._visit_call_args(node)
                return

            # check for Class.func(inst) form for base class method call
            if (len(parts) == 2 and parts[0] in self.mro_names and node.args and
                    isinstance(node.args[0], ast.Name) and node.args[0].id == 'self'):
                cname, func = parts
                for c in self.mro:
                    if c.__name__ == cname:
                        # attribute the call to the class in c's MRO that actually defines func
                        for sub_c in inspect.getmro(c):
                            if func in sub_c.__dict__:
                                c = sub_c
                                break
                        self._record(c, func)
                        self._visit_call_args(node)
                        return

            self.generic_visit(node)
            return

        # check for super() call
        if isinstance(node.func, ast.Attribute) and isinstance(node.func.value, ast.Call):
            base = self._super_base(node.func.value)
            if base is not None:
                self._record(base, node.func.attr)
                self._visit_call_args(node)
                return

        # unrecognized call (including unresolvable super calls): keep looking inside it
        self.generic_visit(node)


def _find_owning_class(mro, func_name):
    """
    Return the full funcname and class where the function is first found in the class MRO.
    """
    # TODO: this won't work for classes with __slots__

    for c in mro:
        if func_name in c.__dict__:
            return '.'.join((c.__name__, func_name)), c

    return None, None


def _get_nested_calls(starting_class, class_, func_name, parent, graph, seen):
    """
    Parse the AST of the given method and all 'self' methods it calls and record owning classes.

    Parameters
    ----------
    starting_class : class
        The class whose call tree is being built; self/super/base-class calls are resolved
        against its MRO.
    class_ : class
        The class that owns ``func_name`` (as found by ``_find_owning_class``).
    func_name : str
        Name of the method to parse.
    parent : str
        Graph node ('ClassName.method') of the method being parsed.
    graph : networkx.DiGraph
        Graph that receives an edge from ``parent`` to each resolved callee.
    seen : set of str
        'ClassName.method' names already parsed; updated in place to stop recursion.
    """
    seen.add('.'.join((class_.__name__, func_name)))

    func = getattr(class_, func_name)
    try:
        src = inspect.getsource(func)
    except (TypeError, OSError):
        # No Python source to parse: e.g. builtins inherited from object (self.__repr__(),
        # self.__class__()), property objects, or functions whose source file is unavailable.
        # Treat the method as a leaf of the call tree instead of aborting the whole build.
        return

    tree = ast.parse(textwrap.dedent(src), mode='exec')
    visitor = _SelfCallCollector(starting_class)
    visitor.visit(tree)

    # now find the actual owning class for each call
    for klass, funcs in visitor.self_calls.items():
        mro = inspect.getmro(klass)
        for f in funcs:
            full, owner = _find_owning_class(mro, f)
            if full is not None:
                graph.add_edge(parent, full)
                if full not in seen:
                    _get_nested_calls(starting_class, owner, f, full, graph, seen)


def get_nested_calls(class_, method_name, stream=sys.stdout):
    """
    Display the call tree for the specified class method and all 'self' class methods it calls.

    Parameters
    ----------
    class_ : class
        The starting class.
    method_name : str
        The name of the class method.
    stream : file-like
        The output stream where output will be displayed. If None, the tree is not written.

    Returns
    -------
    networkx.DiGraph
        A graph containing edges from methods to their sub-methods. Its root is an anonymous
        sentinel node whose only child is the requested method (the graph is empty if the
        method can't be found).
    """
    # moved this class def in here to keep the numpy doc scraper from barfing due to
    # stuff in nx.DiGraph.
    class OrderedDiGraph(nx.DiGraph):
        """
        A DiGraph using OrderedDicts for internal storage.
        """

        # NOTE(review): networkx >= 2.0 appears to name the adjacency factories
        # adjlist_outer_dict_factory / adjlist_inner_dict_factory, so adjlist_dict_factory may
        # be a 1.x leftover -- confirm against the installed networkx version.
        node_dict_factory = OrderedDict
        adjlist_dict_factory = OrderedDict
        edge_attr_dict_factory = OrderedDict

    graph = OrderedDiGraph()
    root = object()  # anonymous sentinel that parents the requested method

    full, owner = _find_owning_class(inspect.getmro(class_), method_name)
    if full is None:
        print("Can't find function '%s' in class '%s'." % (method_name, class_.__name__))
        return graph

    graph.add_edge(root, full)
    _get_nested_calls(class_, owner, method_name, full, graph, set())

    if stream is None:
        return graph

    # iterative depth-first walk; a node's children are expanded only the first time it is
    # printed, so shared callees (and cycles) are listed but not re-expanded
    exhausted = object()
    expanded = {root}
    pending = [(0, iter(graph[root]))]
    while pending:
        depth, kids = pending[-1]
        child = next(kids, exhausted)
        if child is exhausted:
            pending.pop()
            continue
        stream.write("%s%s\n" % (' ' * depth, child))
        if child not in expanded:
            pending.append((depth + 1, iter(graph[child])))
            expanded.add(child)

    return graph
def _calltree_setup_parser(parser): """ Set up the command line options for the 'openmdao call_tree' command line tool. """ parser.add_argument('method_path', nargs=1, help='Full module path to desired class method, e.g., ' '"openmdao.components.exec_comp.ExecComp.setup".') parser.add_argument('-o', '--outfile', action='store', dest='outfile', default='stdout', help='Output file. Defaults to stdout.') def _calltree_exec(options, user_args): """ Process command line args and call get_nested_calls on the specified class method. """ parts = options.method_path[0].split('.') if len(parts) < 3: raise RuntimeError("You must supply the full module path to the function, " "for example: openmdao.api.Group._setup.") class_name = parts[-2] func_name = parts[-1] modpath = '.'.join(parts[:-2]) sys.path.append(os.getcwd()) mod = importlib.import_module(modpath) klass = getattr(mod, class_name) stream_map = {'stdout': sys.stdout, 'stderr': sys.stderr} stream = stream_map.get(options.outfile) if stream is None: stream = open(options.outfile, 'w') get_nested_calls(klass, func_name, stream) def _target_iter(targets): for target in targets: if isinstance(target, ast.Tuple): for t in target.elts: yield t else: yield target class _AttrCollector(ast.NodeVisitor): """ An ast.NodeVisitor that records class attribute names. 
""" def __init__(self, class_dict): super().__init__() self.class_dict = class_dict self.class_stack = [] self.func_stack = [] self.names = None self.decnames = None def get_attributes(self): return self.class_dict def visit_ClassDef(self, node): full_name = '.'.join(self.class_stack[:] + [node.name]) self.class_stack.append(full_name) self.class_dict[full_name] = set() for stmt in node.body: self.visit(stmt) self.class_stack.pop() if self.func_stack: # ignore classes nested in functs del self.class_dict[full_name] def visit_FunctionDef(self, node): self.func_stack.append(node.name) for stmt in node.body: self.visit(stmt) self.func_stack.pop() if self.class_stack: # see if this is a property, and if so, treat as an attribute for dec in node.decorator_list: self.decnames = [] self.visit(dec) if len(self.decnames) == 1 and self.decnames[0] == 'property': self.class_dict[self.class_stack[-1]].add(node.name) self.decnames = None def visit_Assign(self, node): if self.class_stack: for t in _target_iter(node.targets): self.names = [] self.visit(t) if len(self.names) > 1 and self.names[0] == 'self': self.class_dict[self.class_stack[-1]].add(self.names[1]) self.names = None def visit_Attribute(self, node): if self.names is not None: self.visit(node.value) self.names.append(node.attr) def visit_Name(self, node): if self.names is not None: self.names.append(node.id) elif self.decnames is not None: self.decnames.append(node.id)
def get_class_attributes(fname, class_dict=None):
    """
    Find all referenced attributes in all classes defined in the given file.

    Parameters
    ----------
    fname : str
        File name.
    class_dict : dict or None
        Dict mapping class names to attribute names. If given, it is updated in place and
        returned.

    Returns
    -------
    dict
        The dict maps class name to a set of attribute names.
    """
    if class_dict is None:
        class_dict = {}

    # Read raw bytes and let ast.parse do the decoding: it honors PEP 263 coding cookies and a
    # UTF-8 BOM and otherwise defaults to UTF-8, whereas open(fname, 'r') would decode with
    # the platform's locale encoding.
    with open(fname, 'rb') as f:
        source = f.read()

    node = ast.parse(source, mode='exec')
    visitor = _AttrCollector(class_dict)
    visitor.visit(node)
    return visitor.get_attributes()
def is_lambda(f):
    """
    Return True if the given function is a lambda function.

    Parameters
    ----------
    f : function
        The function to check.

    Returns
    -------
    bool
        True if the given function is a lambda function.
    """
    # only function objects carry the special '<lambda>' name given to lambda expressions
    if not isinstance(f, LambdaType):
        return False
    return f.__name__ == "<lambda>"
class LambdaPickleWrapper(object):
    """
    A wrapper for a lambda function that allows it to be pickled.

    Parameters
    ----------
    lambda_func : function
        The lambda function to be wrapped.

    Attributes
    ----------
    _func : function
        The lambda function.
    _src : str
        The isolated source of the lambda function.
    """

    def __init__(self, lambda_func):
        """
        Initialize the wrapper.

        Parameters
        ----------
        lambda_func : function
            The lambda function to be wrapped.
        """
        self._func = lambda_func
        self._src = None

    def __call__(self, *args, **kwargs):
        """
        Call the lambda function.

        Parameters
        ----------
        *args : list
            Positional arguments.
        **kwargs : dict
            Keyword arguments.

        Returns
        -------
        object
            The result of the lambda function.
        """
        return self._func(*args, **kwargs)

    def __getstate__(self):
        """
        Return the state of this object for pickling.

        The lambda function is converted to a string for pickling. The isolated source is
        computed (and cached in ``_src``) before the state is copied so that it travels with
        the pickle. Otherwise an unpickled copy, whose lambda was rebuilt with ``eval`` and
        so has no retrievable source, could not be pickled again.

        Returns
        -------
        dict
            The state of this object.
        """
        src = self._getsrc()
        state = self.__dict__.copy()
        state['_func'] = src
        return state

    def __setstate__(self, state):
        """
        Restore the state of this object after pickling.

        Parameters
        ----------
        state : dict
            The state of this object.
        """
        self.__dict__.update(state)
        # NOTE: the lambda is rebuilt with eval() in this module's namespace, so free names in
        # the lambda resolve against this module's globals. eval here is no less safe than
        # unpickling itself -- never unpickle untrusted data.
        self._func = eval(state['_func'])  # nosec

    def _getsrc(self):
        """
        Return (and cache) the isolated source of the wrapped lambda.

        Returns
        -------
        str
            Source of the lambda expression.

        Raises
        ------
        RuntimeError
            If the source could not be isolated (``ast.unparse`` requires python 3.9+).
        """
        if self._src is None:
            self._src = _LambdaSrcFinder(self._func).src
            if self._src is None:
                raise RuntimeError("The fix for pickling lambda functions only works for python "
                                   "3.9 and above. Try updating to a newer python version or "
                                   "replacing the lambda with a regular function.")
        return self._src
class _LambdaSrcFinder(ast.NodeVisitor): """ Given a lambda function, isolate the lambda function source from any surrounding code. """ def __init__(self, func): super().__init__() self.src = None # note that inspect.getsource gives the source for the line that contains the lambda # function, so we have to isolate the lambda function itself self.visit(ast.parse(textwrap.dedent(inspect.getsource(func)), filename='<string>')) def visit_Lambda(self, node): if self.src is not None: # it's possible to have multiple lambdas defined on the same line, so raise an error # if we find more than one. raise RuntimeError("Only one lambda function is allowed per line when using " "_LambdaWrapper.") try: self.src = ast.unparse(node) except AttributeError: # ast.unparse was added in python 3.9 self.src = None if __name__ == '__main__': import pprint pprint.pprint(get_class_attributes(__file__))