#!/usr/bin/env python
# Copyright (c) 2009 Aymeric Augustin

"""Determine which import statements may be unnecessary in Python modules.

This helps clean up imports, especially after extensive refactoring.

This script inspects each module's AST in two steps:
  - it searches import statements;
  - it searches references to identifiers.
Then it determines which identifiers are imported but never referenced.

Using the AST ensures that every possible syntax involving 'import'
is properly handled.

Given the dynamic nature of Python, this script cannot be perfect.
There are ways to reference a module without using its identifier.
However, it works reasonably well in non-pathological situations.

Here are the limitations you must be aware of:
  - If a module is referenced only by some means other than its
    identifier, it will be reported as unused (false positive)
  - If the name of an unused module is used as an identifier for something
    else (for instance, sys = 1), it will not be reported (false negative)
  - from ... import * statements are ignored:
      . checking if at least one symbol imported by such a statement is
        used would be complicated
      . this is considered bad style, because the list of imported symbols
        is not explicit, so exercise caution anyway
  - from __future__ import ... statements receive a special treatment:
      . nested_scopes and generators are always reported
      . with_statement is checked
      . division and absolute_import are always ignored, because there is
        no way to check them
      . features that may be added in the future will also be ignored
    This was accurate for Python 2.5, the mainstream version when this
    script was written.

This script does not attempt to determine which import statements
may be missing. This is the job of your testing suite.
"""

from __future__ import with_statement

try:
    import ast
except ImportError:
    import _ast as ast
import logging
import os.path
import sys


# ~doc/library/ast.html#abstract-grammar
# Import statements are:
# | Import(alias* names)
# | ImportFrom(identifier module, alias* names, int? level)
# alias = (identifier name, identifier? asname)


def find_imports(tree):
    """Return the set of names bound by import statements anywhere in `tree`.

    Walks `tree` recursively. For every imported alias it records the
    ``asname`` when one is given, otherwise the (possibly dotted) ``name``.

    Special cases:
      - ``from ... import *`` cannot be checked: a warning is logged and
        nothing is recorded for that statement.
      - ``from __future__ import ...`` features are never recorded under
        their own (or aliased) names. ``nested_scopes`` and ``generators``
        trigger a "Probably obsolete" warning. ``with_statement`` is recorded
        as the pseudo-name ``__with_statement__`` so it can be matched against
        the ``with`` blocks collected by find_names().

    :param tree: an AST node; any non-AST object yields an empty set.
    :returns: a set of identifier strings.
    """
    imports = set()
    # Look for imports
    if isinstance(tree, ast.Import):
        imports = set(alias.asname or alias.name for alias in tree.names)
    elif isinstance(tree, ast.ImportFrom):
        imports = set(alias.asname or alias.name for alias in tree.names)
        # `level` is the number of leading dots of a relative import. It may
        # be missing or None in old _ast versions, hence the fallback to 0.
        level = getattr(tree, 'level', 0) or 0
        # Handle from ... import *
        if '*' in imports:
            # Spell relative imports as written ("from . import *") instead
            # of printing the None module of a bare relative import.
            source = '.' * level + (tree.module or '')
            logging.warning('    Will not check line %d:'
                            ' from %s import *',
                            tree.lineno, source)
            imports.remove('*')
        # Handle from __future__ import ...  Only the absolute form is a
        # future statement; "from .__future__ import x" is an ordinary
        # relative import of a local module.
        if tree.module == '__future__' and not level:
            for alias in tree.names:
                feature = alias.name
                # Future features may be aliased ("import division as d").
                # Drop the name that was actually recorded above rather than
                # assuming it equals the feature name, which raised KeyError.
                imports.discard(alias.asname or feature)
                if feature in ('nested_scopes', 'generators'):
                    logging.warning('    Probably obsolete:'
                                    ' line %d: from %s import %s',
                                    tree.lineno, tree.module, feature)
                elif feature == 'with_statement':
                    imports.add('__with_statement__')
    # Recurse if appropriate
    if not isinstance(tree, ast.AST) or tree._fields is None:
        return imports
    for field in tree._fields:
        children = getattr(tree, field)
        if isinstance(children, ast.AST):
            imports.update(find_imports(children))
        elif isinstance(children, list):
            for child in children:
                imports.update(find_imports(child))
    return imports


# ~doc/library/ast.html#abstract-grammar
# Identifiers occur in:
# | Attribute(expr value, identifier attr, expr_context ctx)
# | Name(identifier id, expr_context ctx)
# We also want to check:
# | With(expr context_expr, expr? optional_vars, stmt* body)


def get_identifier_name(attr):
    """Return the dotted name spelled by a Name/Attribute chain, or None.

    ``x`` gives ``'x'`` and ``a.b.c`` gives ``'a.b.c'``. A chain that is not
    rooted at a plain name (for example ``f().y``), or any other node type,
    gives None.
    """
    segments = []
    node = attr
    # Walk from the outermost attribute down to the root of the chain.
    while type(node) is ast.Attribute:
        segments.append(node.attr)
        node = node.value
    if type(node) is not ast.Name:
        return None
    segments.append(node.id)
    segments.reverse()
    return '.'.join(segments)


def find_names(tree):
    """Return the set of identifiers referenced anywhere in `tree`.

    Plain names and dotted attribute chains rooted at a name are collected,
    together with every shorter prefix of the chain. For ``os.path.join``
    the set receives ``os.path.join``, ``os.path`` and ``os``. Each ``with``
    block contributes the pseudo-name ``__with_statement__``. Any non-AST
    object yields an empty set.
    """
    found = set()
    pending = [tree]
    # Explicit work stack instead of recursion; visiting order does not
    # matter because the result is a set.
    while pending:
        node = pending.pop()
        if not isinstance(node, ast.AST):
            continue
        if isinstance(node, (ast.Name, ast.Attribute)):
            ident = get_identifier_name(node)
            if ident:
                found.add(ident)
        elif isinstance(node, ast.With):
            found.add('__with_statement__')
        if node._fields is None:
            continue
        for field in node._fields:
            value = getattr(node, field)
            if isinstance(value, ast.AST):
                pending.append(value)
            elif isinstance(value, list):
                # Non-AST list items (e.g. plain strings) are skipped when
                # popped.
                pending.extend(value)
    return found


def find_imports_and_names(module):
    """Parse the source file at path `module` and analyse its AST.

    :returns: a ``(imports, names)`` pair, as produced by find_imports() and
        find_names() on the module's AST.
    Raises IOError if the file cannot be read, and SyntaxError if its
    contents do not compile.
    """
    with open(module) as handle:
        source = handle.read()
    # PyCF_ONLY_AST makes compile() return the AST instead of a code object;
    # the final True is dont_inherit, so this script's own __future__ flags
    # are not applied to the module being checked.
    syntax_tree = compile(source, os.path.basename(module), 'exec',
                          ast.PyCF_ONLY_AST, True)
    return find_imports(syntax_tree), find_names(syntax_tree)


def check_imports(modules):
    """Report possibly unused imports in each of the given modules.

    Each module is parsed, and the names it imports are compared with the
    names it references. Findings are logged. Modules that fail to compile
    are reported and skipped. Results are reported in the order the modules
    were given, each module at most once.

    :param modules: a sequence of paths to Python source files.
    :returns: True if no unused import was found, False otherwise.
    Calls sys.exit(1) if a module cannot be read.
    """
    results = {}
    count = len(modules)
    for index, module in enumerate(modules):
        if count > 1:
            logging.info("Analysing %s (%d/%d)", module, index + 1, count)
        try:
            results[module] = find_imports_and_names(module)
        except IOError:
            # sys.exc_info() instead of "except ... as e" / "except ..., e"
            # keeps this valid on old Python 2 releases and on Python 3.
            error = sys.exc_info()[1]
            logging.error("Error while reading module %s:", module)
            logging.error("    %s", error)
            sys.exit(1)
        except SyntaxError:
            error = sys.exc_info()[1]
            logging.error("Error while compiling module %s:", module)
            logging.error("    %s", error)
            logging.error("This module will not be checked")
    all_valid = True
    if count > 1:
        logging.info('-' * 72)
    # Iterate over the argument list rather than the dict so that the
    # report order is deterministic. pop() guarantees a module passed twice
    # is reported only once. Modules that failed to compile are absent.
    for module in modules:
        if module not in results:
            continue
        imports, names = results.pop(module)
        unused = imports - names
        if '__with_statement__' in unused:
            all_valid = False
            unused.discard('__with_statement__')
            logging.error("Certainly unused import in %s: with_statement",
                          module)
        if unused:
            all_valid = False
            logging.error("Possibly unused imports in %s: %s",
                          module, ', '.join(sorted(unused)))
    if all_valid:
        # A clean result is not an error: log it at INFO level.
        logging.info("No unused imports were found")
    return all_valid


def usage():
    """Write the command-line synopsis to standard output."""
    program = os.path.basename(sys.argv[0])
    synopsis = [
        'Usage: %s module.py [module2.py ...]' % program,
        '       %s --help' % program,
    ]
    sys.stdout.write('\n'.join(synopsis) + '\n')


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO, format='%(message)s')
    arguments = sys.argv[1:]
    if not arguments:
        # No module given: show the synopsis and exit with a usage error.
        usage()
        sys.exit(2)
    if arguments[0] == '--help':
        usage()
        # A blank line, then the module docstring.
        sys.stdout.write('\n%s\n' % (__doc__,))
        sys.exit(0)
    # Exit status: 0 when every module is clean, 1 otherwise.
    if check_imports(arguments):
        sys.exit(0)
    sys.exit(1)
