#!/usr/bin/env python
# Copyright (c) 2009 Aymeric Augustin

"""Draw the class graph of a set of Python modules.

This script must be executed from a directory of sys.path, and arguments
must be passed as relative paths. Otherwise, it may fail to:
  - import other modules located within the same package;
  - determine fully qualified module identifiers.

Simple example:
$ PYTHONPATH=. class_graph.py module.py

Advanced example:
$ find <package> -name '*.py' -a -not -name 'test_*.py' -print0 | \
> PYTHONPATH=. xargs -0 class_graph.py

This script inspects each module in three steps.

1) It resolves all imports, mapping module-local names to fully qualified
names (e.g. relative to a directory of sys.path). This is done by walking the
module's AST and looking for import statements. For each imported module, the
local directory is searched ("relative" import), then each directory of
sys.path ("absolute" import), in order to determine the fully qualified name
of the module.

2) It searches for class definitions by walking the AST a second time. The
fully qualified name of each class is known, provided the rules for passing
arguments are followed (see above). The list of the base classes of each class
is saved. At this step, base classes are identified by their local names.

3) Local names of base classes are transformed into fully qualified names.
This is done by recursively inspecting the symbol tables within the module.
Since each class creates a local scope with its own symbol table, this
provides the list of the classes defined in the module. For each class, the
list of its base classes is looked up in the results of step 2. Then the
identifier scoping rules are applied to determine the fully qualified name
of each base class, using the results of step 1.

Finally, a consistency check is performed, to ensure that the same set
of classes was gathered by walking the abstract syntax tree and the symbol
tables of the module.

Given the dynamic nature of Python, this script can not be perfect.
There are ways to reference a class without using an identifier, and such
constructs can not be statically resolved. Example:

# Supported
class A(object):
    pass

# Unsupported
o = object
class B(o):
    pass

However, the script works reasonably well in non-pathological situations.
"""

from __future__ import with_statement

import __builtin__
try:
    import ast
except ImportError:
    import _ast as ast
import imp
import logging
import os.path
import sys
import symtable

try:
    import pydot
except ImportError:
    sys.stderr.write("Please install pydot <http://dkbza.org/pydot.html>"
                     " for graph rendering.\n")


# Import rules:

# ~doc/library/ast.html#abstract-grammar
# Import statements are:
# | Import(alias* names)
# | ImportFrom(identifier module, alias* names, int? level)
# alias = (identifier name, identifier? asname)

# ~doc/tutorial/modules.html#the-module-search-path
# Modules are searched first in the current directory, then in the list of
# directories given by the variable sys.path.

# See also:
# ~doc/reference/simple_stmts.html#import


# Useful special attributes:

# ~doc/tutorial/modules.html
# Within a module, the module's name (as a string) is available as the value
# of the global variable __name__.

# ~doc/tutorial/modules.html#packages-in-multiple-directories
# Packages support one more special attribute, __path__. This is initialized
# to be a list containing the name of the directory holding the package's
# __init__.py before the code in that file is executed.

# __package__ is used only for relative imports, see PEP 366.

# See also:
# Demo/imputil/knee.py in the Python source distribution


def import_module(name, parent=None):
    """Return module or package 'name', importing it if necessary.

    'name' must be a simple name, i.e. it must not contain dots.

    If 'parent' is provided, the module is searched within this package,
    and only there; the resulting module is also bound as attribute 'name'
    of 'parent'. Otherwise the default search of imp.find_module applies.

    ImportError raised by imp.find_module is propagated when the module
    can not be located.
    """
    assert '.' not in name

    if parent:
        fullname = '%s.%s' % (parent.__name__, name)
    else:
        fullname = name

    # Reuse an already imported module/package. sys.modules may map a name
    # to None, which is handled like a missing entry, see
    # http://www.python.org/doc/essays/packages.html
    cached = sys.modules.get(fullname)
    if cached is not None:
        return cached

    # imp.find_module makes it possible to specify an exact search path:
    # the __path__ of 'parent' when there is one, the default search
    # otherwise
    search_path = parent and parent.__path__
    handle, pathname, description = imp.find_module(name, search_path)

    # NOTE: sys.path is not adjusted while the module is loaded
    try:
        loaded = imp.load_module(fullname, handle, pathname, description)
    finally:
        # imp.find_module returns an open file for plain modules; it must
        # be closed by the caller
        if handle:
            handle.close()

    if parent:
        setattr(parent, name, loaded)
    return loaded


def find_within_module(name, parent):
    """Locate object 'name' within module or package 'parent'.

    'name' may be a compound name e.g. contain dots. Each component is
    looked up first as an attribute of the object found so far, then, when
    that object is a package, as a submodule or subpackage to import.

    Returns the object found; 'parent' itself is returned when 'name' is
    empty. Raises NameError if 'name' can not be found.
    """
    head, _, tail = name.partition('.')

    # Test if parent is the final object
    if head == '':
        return parent

    # Test if parent contains the next component as an attribute. This also
    # covers compound names inside plain modules, e.g. 'path.join' in os.
    failure = None
    try:
        child = getattr(parent, head)
    except AttributeError:
        pass
    else:
        try:
            return find_within_module(tail, child)
        except NameError:
            # The attribute may shadow a submodule with the same name:
            # remember the error and give the submodule a chance below
            failure = sys.exc_info()[1]

    # If 'parent' is a package, recursively search subpackages or modules
    # In other words, try to import the next component into parent
    if hasattr(parent, '__path__'):
        try:
            module = import_module(head, parent)
        except ImportError:
            # Not a submodule either: reported as a NameError below
            module = None
        if module:
            return find_within_module(tail, module)

    # Give up
    if failure is not None:
        raise failure
    raise NameError("%r not found in %r."
                    % (head, getattr(parent, '__name__', parent)))


def resolve_import(name, pymod):
    """Translate imported identifier 'name' into a fully qualified identifier.

    'name' must be imported in module 'pymod'.

    The package containing 'pymod' (if any) is searched first, then the
    directories of sys.path. Resolving a name imports the modules it goes
    through.

    Returns the fully qualified identifier, or None when the first
    component of 'name' can not be imported at all (a critical message is
    logged). Raises NameError when the first component is imported but the
    remainder of 'name' can not be found inside it.
    """
    # Since only modules are accepted as input, pymod can not be a package.
    # It may be an __init__ module, though.
    assert not hasattr(pymod, '__path__')

    if '.' in pymod.__name__:
        # Maybe we're importing from a parent package, maybe from sys.path
        parname = pymod.__name__.rpartition('.')[0]
        pypar = sys.modules[parname]
        assert pypar.__name__ == parname
    else:
        # We're importing from sys.path
        pypar = None

    # Since name is imported and fully qualified, let's import the first
    # component, which is a module or a package, and then search inside
    head, _, tail = name.partition('.')

    # Test for an import in the same package
    if pypar:
        try:
            pypkg = import_module(head, pypar)
        except ImportError:
            pass
        else:
            # Raises a NameError if not found. The object itself is not
            # tested for truth: a falsy object (False, 0, an empty
            # container...) is a perfectly valid import target.
            find_within_module(tail, pypkg)
            return '%s.%s' % (pypar.__name__, name)

    # Test for an import in a directory of sys.path
    try:
        pypkg = import_module(head)
    except ImportError:
        logging.critical("Could not import %r. Is PYTHONPATH set correctly?",
                         name)
        return None

    # Raises a NameError if not found
    find_within_module(tail, pypkg)
    return name



def _exported_names(module):
    """Return the names that 'from module import *' would bind."""
    try:
        return list(module.__all__)
    except AttributeError:
        # Without __all__, a star import binds every public name
        return [name for name in dir(module) if not name.startswith('_')]


def map_import_bindings(tree, pymod):
    """Map imported names in AST 'tree' to fully qualified names.

    This method handles imports in module 'pymod'.

    Returns a dict mapping each name bound by an import statement found
    anywhere in 'tree' to the fully qualified name computed by
    resolve_import. Imports from __future__ are ignored.
    """
    imports = dict()
    # Look for import ...
    if isinstance(tree, ast.Import):
        for n in tree.names:
            relative = n.asname or n.name
            absolute = resolve_import(n.name, pymod)
            imports[relative] = absolute
            logging.debug("Mapping imported symbol %r to %r",
                          relative, absolute)
    # Look for from ... import ...
    # NOTE: tree.module is None for 'from . import name'. The relative
    # import level itself is not inspected: resolve_import always searches
    # the package containing 'pymod' first, which matches level 1 only.
    elif isinstance(tree, ast.ImportFrom) and tree.module != '__future__':
        for n in tree.names:
            if n.name == '*':
                if tree.module:
                    # This has the side effect of importing tree.module,
                    # which can then be found in sys.modules
                    imported = resolve_import(tree.module, pymod)
                else:
                    # 'from . import *' reads the enclosing package
                    imported = pymod.__name__.rpartition('.')[0]
                source = sys.modules.get(imported) if imported else None
                if source is None:
                    logging.warning("Could not resolve the source of a star"
                                    " import from %r.", tree.module)
                    continue
                for exported in _exported_names(source):
                    absolute = '%s.%s' % (imported, exported)
                    imports[exported] = absolute
                    logging.debug("Mapping imported symbol %r to %r",
                                  exported, absolute)
            else:
                relative = n.asname or n.name
                if tree.module:
                    imported = '%s.%s' % (tree.module, n.name)
                else:
                    imported = n.name
                absolute = resolve_import(imported, pymod)
                imports[relative] = absolute
                logging.debug("Mapping imported symbol %r to %r",
                              relative, absolute)
    # Recurse if appropriate
    if not isinstance(tree, ast.AST) or tree._fields is None:
        return imports
    for field in tree._fields:
        children = getattr(tree, field)
        if isinstance(children, ast.AST):
            imports.update(map_import_bindings(children, pymod))
        elif isinstance(children, list):
            for child in children:
                imports.update(map_import_bindings(child, pymod))
    return imports


# ~doc/library/ast.html#abstract-grammar
# Identifiers occur in:
# | Attribute(expr value, identifier attr, expr_context ctx)
# | Name(identifier id, expr_context ctx)
# We also want to check:
# | With(expr context_expr, expr? optional_vars, stmt* body)


def get_identifier_name(attr):
    """Get the dotted name spelled by a Name or Attribute AST node.

    Returns the identifier as a string, e.g. 'pkg.mod.Class', or None when
    'attr' (or the expression an attribute is read from) is any other kind
    of node, such as a call or a subscript.
    """
    if isinstance(attr, ast.Name):
        return attr.id
    if isinstance(attr, ast.Attribute):
        basename = get_identifier_name(attr.value)
        if basename:
            return '%s.%s' % (basename, attr.attr)
    return None


# ~doc/library/ast.html#abstract-grammar
# Class definition statements are:
# | ClassDef(identifier name, expr* bases, stmt* body, expr *decorator_list)


# ~doc/reference/compound_stmts.html#class-definitions
# > Each item in the inheritance list should evaluate to a class object or
# > class type which allows subclassing.
# However, it is most often a list of Names or Attributes. For the sake of
# simplicity, we will only support this.


def build_partial_inheritance(tree, path):
    """Map class names in AST 'tree' to the list of their base classes.

    Fully qualified names are used for the classes, local names for the
    base classes.

    This method handles classes in scope 'path'.

    Base classes that are not written as a plain (possibly dotted)
    identifier can not be named: they are reported with a warning and left
    out of the list.
    """
    inheritance = dict()
    # Look for classes
    if isinstance(tree, ast.ClassDef):
        path = '%s.%s' % (path, tree.name)
        bases = []
        for base in tree.bases:
            identifier = get_identifier_name(base)
            if identifier is None:
                logging.warning("Ignoring a base class of %r that is not"
                                " a plain identifier.", path)
            else:
                bases.append(identifier)
        inheritance[path] = bases
        logging.debug("Base classes for %r are %r.", path, bases)
    # Functions also create symbol table nesting
    # Modules need not be taken into account
    elif isinstance(tree, ast.FunctionDef):
        path = '%s.%s' % (path, tree.name)
    # Recurse if appropriate
    if not isinstance(tree, ast.AST) or tree._fields is None:
        return inheritance
    for field in tree._fields:
        children = getattr(tree, field)
        if isinstance(children, ast.AST):
            inheritance.update(build_partial_inheritance(children, path))
        elif isinstance(children, list):
            for child in children:
                inheritance.update(build_partial_inheritance(child, path))
    return inheritance


# ~doc/reference/executionmodel.html#naming-and-binding
# Read it. Reminders:
#   - block = (module), class, function
#   - block scope
#   - name = local, global, free
#   - global => module + builtin, else scopes


def resolve_imported_symbol(name, imports):
    """Transform an imported class name into a fully qualified name.

    Dotted prefixes of 'name' are tried from the shortest to the longest.
    The first prefix that is a key of 'imports' is replaced by the value
    it maps to, and the remaining components are appended unchanged.

    'imports' is a mapping describing the import bindings.

    Returns None when no prefix of 'name' is a known import.
    """
    prefix, _, remainder = name.partition('.')
    while True:
        if prefix in imports:
            target = imports[prefix]
            if remainder:
                return '%s.%s' % (target, remainder)
            return target
        if not remainder:
            # Every prefix has been tried
            return None
        # Move one more component from the remainder to the prefix
        component, _, remainder = remainder.partition('.')
        prefix = '%s.%s' % (prefix, component)


def _resolve_bound_symbol(symbol, cl, head, scope_path, imports, bases,
                          kind, where):
    """Resolve base class 'cl' whose first component is bound in a scope.

    'symbol' is the symbol table entry for 'head', the first component of
    'cl', in the scope whose fully qualified name is 'scope_path'. The
    fully qualified name of the base class, if any, is appended to 'bases'.
    'kind' ('local' or 'global') and 'where' ('block' or 'module') only
    appear in log messages.
    """
    # > Each assignment or import statement occurs within a block
    # > defined by a class or function definition or at the module
    # > level (the top-level code block).
    # I assume that local symbols are (assigned xor imported)
    if symbol.is_assigned() and not symbol.is_imported():
        match = scope_path + '.' + cl
        logging.debug("Assigned %s symbol %r matched to %r.",
                      kind, cl, match)
        bases.append(match)
    elif not symbol.is_assigned() and symbol.is_imported():
        match = resolve_imported_symbol(cl, imports)
        if match:
            bases.append(match)
            logging.debug("Imported %s symbol %r matched to %r.",
                          kind, cl, match)
        else:
            logging.critical("%r is imported, but was not"
                             " found in imports.", cl)
    else:
        logging.critical("Assertion failed: local symbol %r is"
                         " not (assigned xor imported) in"
                         " %s.", head, where)


def _resolve_unbound_symbol(cl, head, imports, bases, failure):
    """Resolve base class 'cl' when 'head' is not bound in the module.

    Such a symbol is builtin, bound by a 'from ... import *' or undefined.
    The name found, if any, is appended to 'bases'. 'failure' ends the
    warning logged when nothing matches.
    """
    if head in dir(__builtin__):
        logging.debug("Builtin symbol %r.", cl)
        bases.append(cl)
        return
    match = resolve_imported_symbol(cl, imports)
    if match:
        bases.append(match)
        logging.debug("Imported global symbol %r matched to %r.", cl, match)
    else:
        logging.warning("Match failed: global symbol %r %s.", head, failure)


def build_full_inheritance(table, path, partial, imports,
                           module, module_path,
                           parent=None, parent_path=None):
    """Map class names in symbol 'table' to the list of their base classes.

    Fully qualified names are used for the classes and the base classes.

    This method handles classes in scope 'path'.

    'partial' is the output of build_partial_inheritance.
    'imports' is the output of map_import_bindings.
    'module' is the symbol table for the module currently inspected, and
    'module_path' its fully qualified import path.
    'parent' is the symbol table for the parent scope, and 'parent_path'
    its fully qualified scope.

    Base classes that can not be resolved are logged and left out. A
    KeyError is raised when 'table' describes a class missing from
    'partial'.
    """
    inheritance = dict()
    # Fully resolve inheritance of classes based on the symbol table
    if table.get_type() == 'class':
        bases = inheritance[path] = []
        for cl in partial[path]:
            # A base class that is not a plain identifier has no name
            if cl is None:
                logging.warning("Ignoring a base class of %r that is not"
                                " a plain identifier.", path)
                continue

            # The base classes are evaluated in the scope enclosing the
            # class, hence the lookup in 'parent'
            head = cl.partition('.')[0]
            symbol = parent.lookup(head)
            is_local, is_global = symbol.is_local(), symbol.is_global()

            # Handle local symbols
            if is_local and not is_global:
                _resolve_bound_symbol(symbol, cl, head, parent_path,
                                      imports, bases, 'local', 'block')

            # Handle global symbols
            # > The global namespace is searched first. If the name is not
            # > found there, the builtin namespace is searched.
            elif is_global and not is_local:

                # A global symbol must be defined at the module level
                try:
                    symbol = module.lookup(head)
                except KeyError:
                    _resolve_unbound_symbol(cl, head, imports, bases,
                                            'was not found')
                    continue

                # This is the same logic at the module level, as above at
                # the block level
                if symbol.is_local() and not symbol.is_global():
                    _resolve_bound_symbol(symbol, cl, head, module_path,
                                          imports, bases, 'global', 'module')

                # A global symbol at the module level is builtin, bound by
                # a 'from ... import *' or undefined
                elif not symbol.is_local() and symbol.is_global():
                    _resolve_unbound_symbol(cl, head, imports, bases,
                                            'was not assigned')

                else:
                    logging.critical("Assertion failed: symbol %r is not"
                                     " (local xor global) in module.", head)

            else:
                logging.critical("Assertion failed: symbol %r is not"
                                 " (local xor global) in block.", head)

    # Recurse: every nested block becomes the scope of its own children
    for child in table.get_children():
        child_path = path + '.' + child.get_name()
        inheritance.update(build_full_inheritance(child, child_path, partial,
                                                  imports, module,
                                                  module_path, table, path))

    return inheritance


def find_classes(module):
    """Maps classes defined in a module to their base classes.

    'module' is the path of a .py file, relative to a directory of
    sys.path. The module is imported, hence executed, during the analysis.

    Fully-qualified identifiers are used for all classes.

    Returns a dict; it is empty when 'module' is rejected because it has
    no .py extension or is not a usable relative path (an error is
    logged). Raises AssertionError when the abstract syntax tree and the
    symbol table of the module do not describe the same set of classes.
    """
    root, ext = os.path.splitext(module)
    if ext != '.py':
        logging.error("Module %r does not have a .py extension.", module)
        return {}

    # Transform the module path into a module name. The path is normalized
    # so that './pkg/mod.py' and 'pkg/mod.py' give the same name.
    head, modname = os.path.normpath(root), []
    while head:
        parent, tail = os.path.split(head)
        # os.path.split stops making progress at the root of an absolute
        # path, and '..' can not be part of a module name
        if parent == head or tail == os.pardir:
            logging.error("Module %r must be passed as a path relative to"
                          " a directory of sys.path.", module)
            return {}
        modname.append(tail)
        head = parent
    modname = '.'.join(reversed(modname))

    # Import the module in the global scope
    try:
        pymod = sys.modules[modname]
    except KeyError:
        __import__(modname, globals={}, locals={}, fromlist=[], level=0)
        pymod = sys.modules[modname]
    assert pymod.__name__ == modname

    # Compile the abstract syntax tree and the symbol table for the module
    # from a single read of the file
    with open(module) as handle:
        source = handle.read()
    # The last argument (dont_inherit) keeps the future statements of this
    # script from leaking into the compiled module
    tree = compile(source, os.path.basename(module),
                   'exec', ast.PyCF_ONLY_AST, True)
    table = symtable.symtable(source, module, 'exec')

    # Compute a mapping of short identifiers to fully-qualified
    # identifiers for imported modules
    logging.debug("")
    logging.debug("Mapping import bindings")
    logging.debug("-----------------------")
    logging.debug("")
    imp_bind = map_import_bindings(tree, pymod)

    # Compute a partial inheritance table mapping fully-qualified
    # identifiers to short identifiers of parent classes
    logging.debug("")
    logging.debug("Building partial inheritance table")
    logging.debug("----------------------------------")
    logging.debug("")
    part_inh = build_partial_inheritance(tree, pymod.__name__)

    # Compute a full inheritance table mapping fully-qualified
    # identifiers to fully-qualified identifiers of parent classes
    logging.debug("")
    logging.debug("Building full inheritance table")
    logging.debug("-------------------------------")
    logging.debug("")
    full_inh = build_full_inheritance(table, pymod.__name__,
                                      part_inh, imp_bind,
                                      table, pymod.__name__)

    # Consistency check. This is an explicit test rather than an assert
    # statement, so that it is still performed under python -O.
    ast_set, st_set = set(part_inh), set(full_inh)
    if ast_set != st_set:
        ast_minus_st, st_minus_ast = ast_set - st_set, st_set - ast_set
        logging.critical("Traversing the abstract syntax tree and the"
                         " symbol table of %s yield different results:"
                         " AST - ST = %r, ST - AST = %r.",
                         module, ast_minus_st, st_minus_ast)
        raise AssertionError("inconsistent class sets for %s" % module)

    logging.debug("")
    logging.debug("")
    return full_inh


def class_graph(modules):
    """Map the classes defined in 'modules' to their base classes.

    'modules' is a sequence of paths, each analysed with find_classes.

    A module that can not be read terminates the script with exit status
    1; a module that can not be compiled is reported and skipped.
    """
    total = len(modules)
    classes = {}
    for index, module in enumerate(modules):
        if total > 1:
            logging.info("Analysing %s (%d/%d)", module, index + 1, total)
        # The exception is read from sys.exc_info() because neither
        # 'except X, e' nor 'except X as e' works on every Python version
        try:
            found = find_classes(module)
        except IOError:
            logging.error("Error while reading module %s:", module)
            logging.error("    %s", sys.exc_info()[1])
            sys.exit(1)
        except SyntaxError:
            logging.error("Error while compiling module %s:", module)
            logging.error("    %s", sys.exc_info()[1])
            logging.error("This module will not be checked")
        else:
            # find_classes may give nothing for a module it rejects
            if found:
                classes.update(found)
    return classes


def render_graph(modules):
    """Render the class graph of 'modules' into the file 'classes.pdf'.

    Every class and every base class becomes a node, and an edge goes from
    each base class to the class deriving from it.

    Returns True when the graph was written, False when pydot is not
    available or reports a failure.
    """
    # pydot is an optional import of this file: fail early and cleanly,
    # before the (possibly long) analysis, if it is missing
    try:
        dot = pydot
    except NameError:
        logging.error("pydot is not available, the graph can not be"
                      " rendered.")
        return False

    classes = class_graph(modules)
    graph = dot.Dot(size='11.69,16.54', ratio='compress', rankdir='LR')

    nodes = dict()
    for child, parents in classes.items():
        if child not in nodes:
            nodes[child] = dot.Node(child)
            graph.add_node(nodes[child])
        for parent in parents:
            if parent not in nodes:
                nodes[parent] = dot.Node(parent)
                graph.add_node(nodes[parent])
            graph.add_edge(dot.Edge(nodes[parent], nodes[child]))

    # Only an explicit False is a failure
    # NOTE(review): assumes pydot raises or returns False on error - confirm
    return graph.write_pdf('classes.pdf', prog='dot') is not False


def usage():
    """Write the command line synopsis of the script to standard output."""
    program = os.path.basename(sys.argv[0])
    synopsis = (
        'Usage: %s module.py [module2.py ...]' % program,
        '       %s --help' % program,
    )
    for line in synopsis:
        sys.stdout.write(line + '\n')


# Command line entry point. Exit status: 0 on success (including --help),
# 1 when rendering fails, 2 when no argument is given.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    if len(sys.argv) == 1: # no args
        usage()
        sys.exit(2)
    if sys.argv[1] == '--help':
        usage()
        # Blank line, then the module documentation
        sys.stdout.write('\n%s\n' % __doc__)
        sys.exit(0)
    # Only an explicit False from render_graph is a failure: a function
    # that returns nothing has completed without raising
    sys.exit(1 if render_graph(sys.argv[1:]) is False else 0)
