'''
Helper functions for running TestCases.
'''

import sys
import os
from os.path import join
from unittest import TestCase, TestLoader, TestSuite, TextTestRunner

from unittestplus.testcaseplus import TestCasePlus


def _make_suite(test_case):
    '''
    Normalise ``test_case`` into a TestSuite.

    - A TestSuite (or subclass) is returned unchanged.
    - A TestCase subclass is expanded, via TestLoader, into a suite of all
      of its test methods.
    - A TestCase instance (a single test, e.g. ``MyTest('test_foo')``) is
      wrapped in a one-test suite.

    Raises TypeError for anything else.
    '''
    if isinstance(test_case, TestSuite):
        return test_case
    if isinstance(test_case, type) and issubclass(test_case, TestCase):
        return TestLoader().loadTestsFromTestCase(test_case)
    if isinstance(test_case, TestCase):
        # A single bound test; previously rejected despite the documented
        # "If given a TestCase, returns it in a suite" contract.
        return TestSuite([test_case])
    msg = "Not a test: %s, type %s" % (test_case, type(test_case))
    raise TypeError(msg)


def combine(*args):
    '''
    Merge any number of TestCase classes and TestSuite objects, in argument
    order, into one freshly created TestSuite.
    '''
    return TestSuite(_make_suite(item) for item in args)


def run(test_case, verbosity=1, descriptions=1):
    '''
    Execute a TestCase class or TestSuite with a TextTestRunner configured
    by ``verbosity`` and ``descriptions``; return the runner's result.
    '''
    tests = _make_suite(test_case)
    text_runner = TextTestRunner(
        verbosity=verbosity,
        descriptions=descriptions,
    )
    return text_runner.run(tests)


def _get_modules(root):
    '''
    Import and return every module found under the directory ``root``.

    ``root`` itself is always scanned; a subdirectory is descended into
    only if it contains an ``__init__.py`` (i.e. it is a package). Each
    ``*.py`` file other than ``__init__.py`` is imported via
    _path_to_module.

    os.listdir returns entries in arbitrary order, so entries are visited
    in sorted order instead. This makes the returned list deterministic,
    and with it the order in which discovered tests run. A directory's own
    modules come first, then each subpackage in sorted name order,
    depth-first.
    '''
    # walks all subdirectories looking for modules.
    # with thanks to Francois Pinard
    modules = []
    stack = [root]
    while stack:
        directory = stack.pop()
        subpackages = []
        for base in sorted(os.listdir(directory)):
            name = join(directory, base)
            if os.path.isdir(name):
                if os.path.isfile(join(name, '__init__.py')):
                    subpackages.append(name)
            elif base.endswith('.py') and base != '__init__.py':
                modules.append(_path_to_module(name))
        # Push in reverse so the first subpackage (in sorted order) is
        # popped, and fully explored, next.
        stack.extend(reversed(subpackages))
    return modules


def _path_to_module(path):
    '''
    Import the module at filesystem ``path`` and return the module object.

    The dotted module name is derived from the path itself: a trailing
    '.py' is dropped and each path separator becomes a '.'. The path is
    therefore expected to be relative to a directory on sys.path.
    NOTE(review): an absolute path will not map to an importable name.

    The path is normalised first, so redundant components such as a
    leading './' (e.g. when walking from root '.') or a doubled separator
    do not produce invalid names like '..pkg.mod'. On Windows,
    normalisation also turns '/' into os.sep.
    '''
    path = os.path.normpath(path)
    if path.endswith('.py'):
        path = path[:-3]
    name = path.replace(os.sep, '.')
    __import__(name)
    return sys.modules[name]


def _is_testcase(clazz):
    '''
    Return True if ``clazz`` is a class deriving from TestCase that is
    neither TestCase itself nor TestCasePlus; False otherwise.
    '''
    if not isinstance(clazz, type):
        return False
    if not issubclass(clazz, TestCase):
        return False
    return clazz is not TestCase and clazz is not TestCasePlus


def _get_testcases(module):
    '''
    Return, in dir() order, the attributes of ``module`` that _is_testcase
    accepts.
    '''
    attributes = (getattr(module, attr_name) for attr_name in dir(module))
    return [obj for obj in attributes if _is_testcase(obj)]


def get_all_tests(root):
    '''
    Discover the TestCase classes in the modules under directory ``root``
    and return them combined into a single TestSuite.

    A TestCase class can be visible from several modules, e.g. defined in
    one and imported into another. Such a class is included only once, at
    its first occurrence, so its tests are not run repeatedly.
    '''
    seen = set()
    all_tests = []
    for module in _get_modules(root):
        for testcase in _get_testcases(module):
            if testcase not in seen:
                seen.add(testcase)
                all_tests.append(testcase)
    return combine(*all_tests)

        
