#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright © 2005 Progiciels Bourbeau-Pinard inc.
# François Pinard <pinard@iro.umontreal.ca>, 2005.

u"""\
Execute a validation suite built in pylib's py.test style.

Usage: pytest [OPTION]... [PATH]...

Options:
  -h   Print this help and exit right away.
  -v   Produce one line per test, instead of a dot per test.
  -n   Do not capture stdout nor stderr.
  -p   Profile validation suite (through "lsprof").

  -f PREFIX     Use PREFIX instead of "test_" for file names.
  -s FILE       Save ordinals of failed tests, one per line.
  -o ORDINALS   Only consider tests having these execution ordinals.
  -k PATTERN    Only retain tests whose names match PATTERN.
  -x PATTERN    Exclude tests whose names match PATTERN.
  -l LIMIT      Stop the validation suite after LIMIT errors.

If -l is not used, the validation will stop after 10 errors.  ORDINALS is
a sequence of comma-separated integers.  Options -k, -o and -x may be
repeated; then a test should match at least one of -k options if any,
one of -o options if any, and none of -x options.

When PATH names a directory, it is recursively searched to find files
matching "test_.*\.py".  Otherwise, it should give a ".py" ended file name
containing tests.  If no PATH is given, the current directory is implied.

Test progression is first displayed on standard error.  Then, unless
-s is selected, failed tests are detailed on stdout and, if there is at
least one such failed test, the return status of this program is non-zero.
"""

# This tool implements a minimal set of specifications stolen from the
# excellent Codespeak's py.test, at a time I really needed py.test to be
# more Unicode-aware.

# Under Python 2, this makes every class defined below without explicit
# bases a new-style class.
__metaclass__ = type
import inspect, os, sys, time, traceback
from StringIO import StringIO

# How many displayable characters in an output line.  This drives the
# wrapping of the progression display and the length of the separator
# lines in the failure report.
WIDTH = 79

# Raised by Main.launch_test once the number of recorded failures equals
# the configured limit; Main.main catches it to stop the run.
class Limit_Reached(Exception): pass

class Main:
    # Driver of the validation suite.  The attributes below are defaults,
    # which `main' overrides according to command-line options.
    prefix = 'test_'            # -f: how test file names should start
    # NOTE: the next three defaults are mutable lists defined at class
    # level, hence shared by all instances for as long as they are not
    # rebound on an instance.
    pattern = []                # -k: tests to retain
    exclusion = []              # -x: tests to exclude
    ordinals = []               # -o: execution ordinals to retain
    verbose = False             # -v: one line per test instead of a dot
    profile = False             # -p: profile the suite through "lsprof"
    limit = 10                  # -l: stop after that many failures
    capture = True              # -n resets it: capture stdout and stderr
    save = None                 # -s: file receiving ordinals of failures

    # For handling setup/teardown laziness.  A delayed setup is a pair
    # (callable, argument), called by `launch_test' right before the first
    # test which really executes.  The `did_tests_*' flags tell if the
    # matching teardown should be called.
    delayed_setup_module = None
    delayed_setup_class = None
    did_tests_in_module = False
    did_tests_in_class = False

    def main(self, *arguments):
        if sys.getdefaultencoding() == 'ascii':
            sys.stdout = Friendly_StreamWriter(sys.stdout)
            sys.stderr = Friendly_StreamWriter(sys.stderr)
        import getopt
        options, arguments = getopt.getopt(arguments, u'f:hk:l:no:ps:vx:')
        for option, value in options:
            if option == u'-f':
                self.prefix = value
            elif option == u'-h':
                sys.stdout.write(__doc__)
                return
            elif option == u'-k':
                self.pattern.append(value)
            elif option == u'-l':
                self.limit = int(value)
            elif option == u'-n':
                self.capture = False
            elif option == u'-o':
                self.ordinals += [
                    int(text) for text in value.replace(u',', u' ').split()]
            elif option == u'-p':
                self.profile = True
            elif option == u'-s':
                self.save = value
            elif option == u'-v':
                self.verbose = True
            elif option == u'-x':
                self.exclusion.append(value)
        if not arguments:
            arguments = u'.',
        if self.pattern:
            import re
            self.pattern = re.compile(u'|'.join(self.pattern))
        else:
            self.pattern = None
        if self.exclusion:
            import re
            self.exclusion = re.compile(u'|'.join(self.exclusion))
        else:
            self.exclusion = None
        write = sys.stderr.write
        self.failures = []
        self.total_count = 0
        self.skip_count = 0
        start_time = time.time()
        if self.profile:
            try:
                import lsprof
            except ImportError:
                write(u"WARNING: profiler unavailable.\n")
                self.profiler = None
            else:
                self.profiler = lsprof.Profiler()
        else:
            self.profiler = None
        try:
            try:
                for argument in arguments:
                    for file_name in self.each_file(argument):
                        self.identifier = file_name
                        self.column = 0
                        self.counter = 0
                        directory, base = os.path.split(file_name)
                        sys.path.insert(0, directory)
                        try:
                            try:
                                module = __import__(base[:-3])
                            except ImportError:
                                if self.save:
                                    self.failures.append(self.total_count + 1)
                                else:
                                    tracing = StringIO()
                                    traceback.print_exc(file=tracing)
                                    self.failures.append(
                                        (self.total_count + 1, file_name,
                                         None, None,
                                         None, None,
                                         str(tracing.getvalue())))
                            else:
                                self.handle_module(file_name, module)
                        finally:
                            del sys.path[0]
                        if self.counter and not self.verbose:
                            text = u'(%d)' % self.counter
                            if self.column + 1 + len(text) >= WIDTH:
                                write(u'\n%5d ' % self.counter)
                            else:
                                text = u' ' + text
                            write(text + u'\n')
            except Exit, exception:
                if not self.verbose:
                    write(u'\n')
                write(u'\n* %s *\n' % str(exception))
            except Limit_Reached:
                if not self.verbose:
                    write(u'\n')
                if not self.save:
                    if len(self.failures) == 1:
                        write(u'\n* One error already! *\n')
                    else:
                        write(u'\n* %d errors already! *\n' % self.limit)
        finally:
            if self.profiler is not None:
                stats = lsprof.Stats(self.profiler.getstats())
                stats.sort(u'inlinetime')
                write(u'\n')
                stats.pprint(15)
            if self.failures:
                if len(self.failures) == 1:
                    text = u"one FAILED test"
                else:
                    text = u"%d FAILED tests" % len(self.failures)
                first = False
            else:
                text = u''
                first = True
            good_count = (self.total_count - self.skip_count
                          - len(self.failures))
            if good_count:
                if first:
                    if good_count == 1:
                        text += u"one good test"
                    else:
                        text += u"%d good tests" % good_count
                    first = False
                else:
                    if good_count == 1:
                        text += u", one good"
                    else:
                        text += u", %d good" % good_count
            if self.skip_count:
                if first:
                    if self.skip_count == 1:
                        text += u"one skipped test"
                    else:
                        text += u"%d skipped tests" % self.skip_count
                    first = False
                else:
                    if self.skip_count == 1:
                        text += u", one skipped"
                    else:
                        text += u", %d skipped" % self.skip_count
            if first:
                text = u"No test"
            summary = (u"\nSummary: %s in %.2f seconds.\n"
                       % (text, time.time() - start_time))
            write(summary)
        if self.save:
            write = file(self.save, u'w').write
            for ordinal in self.failures:
                write(u'%d\n' % ordinal)
        else:
            write = sys.stdout.write
            for (ordinal, prefix, function, arguments, stdout, stderr,
                 tracing) in self.failures:
                write(u'\n' + u'=' * WIDTH + u'\n')
                write(u'%d. %s\n' % (ordinal, prefix))
                if function and function.__name__ != os.path.basename(prefix):
                    write(u"    Fonction %s\n" % function.__name__)
                if arguments:
                    for counter, argument in enumerate(arguments):
                        write(u"    Arg %d = %r\n" % (counter + 1, argument))
                for buffer, titre in ((stdout, u'STDOUT'), (stderr, u'STDERR')):
                    if buffer:
                        write(u'\n' + titre + u' >>>\n')
                        write(buffer)
                        if not buffer.endswith(u'\n'):
                            write(u'\n')
                write(u'-' * WIDTH + u'\n')
                write(tracing)
            if self.failures:
                write(summary)
                sys.exit(1)

    def each_file(self, path):
        if os.path.isdir(path):
            stack = [path]
            while stack:
                directory = stack.pop(0)
                for base in sorted(os.listdir(directory)):
                    file_name = os.path.join(directory, base)
                    if base.startswith(self.prefix) and base.endswith(u'.py'):
                        yield file_name
                    elif os.path.isdir(file_name):
                        stack.append(file_name)
        else:
            if not path.endswith(u'.py'):
                sys.exit("%s: Should have a `.py' suffix." % path)
            yield path

    def handle_module(self, prefix, module):
        collection = []
        for name, objet in inspect.getmembers(module):
            if name.startswith(u'Test') and inspect.isclass(objet):
                if getattr(object, u'disabled', False):
                    continue
                minimum = None
                for _, method in inspect.getmembers(objet, inspect.ismethod):
                    number = method.im_func.func_code.co_firstlineno
                    if minimum is None or number < minimum:
                        minimum = number
                if minimum is not None:
                    collection.append((minimum, name, objet, False))
            elif name.startswith(u'test_') and inspect.isfunction(objet):
                code = objet.func_code
                collection.append((code.co_firstlineno, name, objet,
                                   bool(code.co_flags & 32)))
        if not collection:
            return
        self.delayed_setup_module = None
        self.did_tests_in_module = False
        if hasattr(module, u'setup_module'):
            self.delayed_setup_module = module.setup_module, module
        for _, name, objet, generator in sorted(collection):
            self.delayed_setup_class = None
            self.did_tests_in_class = False
            if inspect.isclass(objet):
                if not getattr(object, u'disabled', False):
                    self.handle_class(prefix + u'/' + name, objet)
            else:
                self.handle_function(prefix + u'/' + name, objet,
                                     generator, None)
        if self.did_tests_in_module and hasattr(module, u'teardown_module'):
            module.teardown_module(module)

    def handle_class(self, prefix, classe):
        collection = []
        for name, method in inspect.getmembers(classe, inspect.ismethod):
            if name.startswith(u'test_'):
                code = method.im_func.func_code
                collection.append((code.co_firstlineno, name, method,
                                   bool(code.co_flags & 32)))
        if not collection:
            return
        # FIXME: Should likely do module setup here!
        instance = classe()
        if hasattr(instance, u'setup_class'):
            self.delayed_setup_module = instance.setup_class, classe
        for _, name, method, generator in sorted(collection):
            self.handle_function(prefix + u'/' + name, getattr(instance, name),
                                 generator, instance)
        if self.did_tests_in_class and hasattr(instance, u'teardown_class'):
            instance.teardown_class(classe)

    def handle_function(self, prefix, function, generator, instance):
        collection = []
        if generator:
            # FIXME: Should likely do class setup here.
            try:
                for counter, arguments in enumerate(function()):
                    collection.append((prefix + u'/' + unicode(counter + 1),
                                       arguments[0], arguments[1:]))
            except Skipped:
                return
        else:
            collection.append((prefix, function, ()))
        for prefix, function, arguments in collection:
            self.launch_test(prefix, function, arguments, instance)

    def launch_test(self, prefix, function, arguments, instance):
        """Execute one test, unless filtered out, and record its outcome.

        PREFIX identifies the test, FUNCTION is called with ARGUMENTS,
        and INSTANCE is either None or the object on which `setup_method'
        and `teardown_method' are sought.  A test rejected by the -x, -k
        or -o selections is counted as skipped and not executed.

        Raise `Exit' if the test asks for it or if the user interrupts,
        and `Limit_Reached' once the number of failures equals the limit.
        """
        # Check if this test should be retained.
        if (self.exclusion is not None and self.exclusion.search(prefix)
              or self.pattern is not None and not self.pattern.search(prefix)
              or self.ordinals and self.total_count+1 not in self.ordinals):
            self.mark_progression(prefix, None)
            return
        # This test should definitely be executed.
        # Pending module and class setups are run now, then forgotten.
        if self.delayed_setup_module is not None:
            self.delayed_setup_module[0](self.delayed_setup_module[1])
            self.delayed_setup_module = None
        if self.delayed_setup_class is not None:
            self.delayed_setup_class[0](self.delayed_setup_class[1])
            self.delayed_setup_class = None
        if instance is not None and hasattr(instance, u'setup_method'):
            instance.setup_method(function)
        if self.capture:
            # Divert both streams into memory for the duration of the test.
            saved_stdout = sys.stdout
            saved_stderr = sys.stderr
            stdout = sys.stdout = StringIO()
            stderr = sys.stderr = StringIO()
        self.activate_profiling()
        try:
            # SUCCESS becomes True for a good test, False for a failed
            # one, and None for a skipped or interrupted one.
            try:
                function(*arguments)
            except Exit:
                success = False
                raise
            except Failed:
                success = False
            except Skipped:
                success = None
            except KeyboardInterrupt:
                success = None
                raise Exit("Interruption!")
            except:
                # Any other exception is a failure of the test.
                success = False
            else:
                success = True
        finally:
            # This runs even while `Exit' is propagating.
            self.deactivate_profiling()
            if self.capture:
                sys.stdout = saved_stdout
                sys.stderr = saved_stderr
                stdout = stdout.getvalue()
                stderr = stderr.getvalue()
            else:
                stdout = None
                stderr = None
            self.mark_progression(prefix, success)
            if success is False:
                if self.save:
                    # Only ordinals are needed when saving them to a file.
                    self.failures.append(self.total_count)
                else:
                    # NOTE(review): the traceback is printed after the
                    # handler has completed; this presumably relies on
                    # Python 2 keeping the handled exception available
                    # -- confirm before porting.
                    tracing = StringIO()
                    traceback.print_exc(file=tracing)
                    self.failures.append(
                        (self.total_count, prefix, function, arguments,
                         stdout, stderr, str(tracing.getvalue())))
            if instance is not None and hasattr(instance, u'teardown_method'):
                instance.teardown_method(function)
            # Let callers know that class and module teardowns are due.
            self.did_tests_in_class = True
            self.did_tests_in_module = True
        if success is not None and len(self.failures) == self.limit:
            raise Limit_Reached

    def mark_progression(self, prefix, succes):
        self.total_count += 1
        if succes is None:
            self.skip_count += 1
        else:
            write = sys.stderr.write
            if self.verbose:
                write(u'%5d. [%s] %s\n' % (self.total_count, prefix,
                                           (u'FAILED', u'ok')[succes]))
            else:
                if self.column == WIDTH:
                    write(u'\n')
                    self.column = 0
                if not self.column:
                    if self.counter:
                        text = u'%5d ' % (self.counter + 1)
                    else:
                        text = self.identifier + u' '
                    write(text)
                    self.column = len(text)
                write(u'E·'[succes])
                self.column += 1
                self.counter += 1

    def activate_profiling(self):
        if self.profiler is not None:
            self.profiler.enable(subcalls=True, builtins=True)

    def deactivate_profiling(self):
        if self.profiler is not None:
            self.profiler.disable()

class Friendly_StreamWriter:
    # Avoid some Unicode nightmares, by allowing both unicode and
    # UTF-8 str strings to be written (given our sources are all UTF-8).
    # Text is encoded according to the preferred encoding of the locale,
    # characters which cannot be encoded being shown as backslash escapes.

    def __init__(self, stream):
        # STREAM is the underlying file-like object, receiving encoded text.
        import codecs, locale
        writer = codecs.getwriter(locale.getpreferredencoding())
        self.stream = writer(stream, 'backslashreplace')

    def write(self, text):
        # Write TEXT, decoding it first as UTF-8 if not already Unicode.
        # `type(u'')' is `unicode' under Python 2.
        text_type = type(u'')
        if not isinstance(text, text_type):
            # Undecodable bytes get replaced rather than raising an error,
            # so that odd output from a test cannot abort the whole report.
            text = text_type(text, 'UTF-8', 'replace')
        self.stream.write(text)

    def writelines(self, lines):
        # Write each of LINES in turn, with the same leniency as `write'.
        for line in lines:
            self.write(line)

    def __getattr__(self, name):
        # Since instances replace sys.stdout and sys.stderr, anything not
        # defined here (`flush', `isatty', `fileno'...) is sought in the
        # wrapped stream.  The guard avoids an endless recursion whenever
        # `stream' itself has not been set yet.
        if name == 'stream':
            raise AttributeError(name)
        return getattr(self.stream, name)

# A single Main instance serves the whole module, and its `main' method is
# exposed as a plain function, to be called as `main(*arguments)'.
run = Main()
main = run.main

# Raised by py.test.exit, and by Main.launch_test on a keyboard interrupt;
# Main.main catches it, displays its message and stops running tests.
class Exit(Exception): pass

# Raised by py.test.fail; Main.launch_test records such a test as failed.
class Failed(Exception): pass

# Raised by py.test.raises when the expected exception did not occur.
class NotRaised(Exception): pass

# Raised by py.test.skip; Main.launch_test counts such a test as skipped,
# and Main.handle_function gives up a generator which raises it.
class Skipped(Exception): pass

class py:
    # Stand-in for the few `py.test' helpers supported by this tool, so
    # test files may call `py.test.exit', `py.test.fail', `py.test.skip'
    # and `py.test.raises' once they got hold of this class.

    class test:

        @staticmethod
        def exit(message="Unknown reason"):
            # Abort the whole validation suite, MESSAGE telling why.
            raise Exit(message)

        @staticmethod
        def fail(message="Unknown reason"):
            # Declare that the current test has failed.
            raise Failed(message)

        @staticmethod
        def skip(message="Unknown reason"):
            # Declare that the current test should be skipped.
            raise Skipped(message)

        @staticmethod
        def raises(expected, *args, **kws):
            # Check that some code raises the EXPECTED exception, and
            # raise `NotRaised' if it does not.  The code is either a
            # string holding an expression to evaluate, or a callable
            # followed by its positional and keyword arguments.  Any
            # exception other than EXPECTED is left to propagate.
            target = args[0]
            # Both str and Unicode strings are expressions: a plain str
            # was once taken for a callable, and could not be called.
            # `type(u'')' is `unicode' under Python 2.
            string_types = str, type(u'')
            try:
                if isinstance(target, string_types) and not kws:
                    # NOTE(review): the expression is evaluated within
                    # this function, so names of the calling test are
                    # presumably not visible -- confirm whether the
                    # caller's namespace should be used instead.
                    eval(target)
                else:
                    target(*args[1:], **kws)
            except expected:
                return
            else:
                raise NotRaised(u"Exception did not happen.")

# When run as a script, hand all command-line arguments (program name
# excepted) to `main'.
if __name__ == u'__main__':
    main(*sys.argv[1:])
