#!/usr/bin/env python3

import argparse
from e3.testsuite.report.index import ReportIndex
from e3.testsuite.result import TestStatus
import os
import os.path
import subprocess
import tempfile
import yaml

# Description passed to argparse.ArgumentParser in parse_options().
descr = """This script will update the session information of a test, and then
rerun the test. It will run the test twice, so it takes some time.  """

# Working directory captured when the module is loaded; replay_test() changes
# back to it after running the replay script from inside a test directory.
curdir = os.getcwd()


def parse_options():
    """Parse the command line of the session-update script.

    Returns:
        The ``argparse.Namespace`` produced by the parser, with attributes
        ``testnames`` (list of positional test names, possibly empty),
        ``rewrite``, ``failed_replay_tests`` and ``skip_rerun_tests``
        (booleans, False unless the flag is given) and ``procs`` (int, 0
        unless ``-j`` is given).
    """
    parser = argparse.ArgumentParser(description=descr)
    parser.add_argument(
        "testnames",
        metavar="testnames",
        nargs="*",
        help="session of these tests will be updated",
    )
    parser.add_argument(
        "--rewrite",
        dest="rewrite",
        action="store_true",
        help="Use rewrite option to update test baselines",
    )
    parser.add_argument(
        "--failed-replay-tests",
        dest="failed_replay_tests",
        action="store_true",
        help="add failed replay tests to list of tests to rerun",
    )
    parser.add_argument(
        "--skip-rerun-tests",
        dest="skip_rerun_tests",
        action="store_true",
        # The flag disables the rerun step, so the help has to say so.
        help="Do not rerun tests after updating sessions",
    )
    parser.add_argument(
        "-j",
        dest="procs",
        type=int,
        help="Number of cores to use to recreate session (default: all)",
        default=0,
    )
    return parser.parse_args()


def testdir(testname):
    testdir = os.path.join("tests", testname)
    if not os.path.exists(testdir):
        testdir = os.path.join("internal", testname)
    if not os.path.exists(testdir):
        testdir = os.path.join("sparklib", testname)
    if not os.path.exists(testdir):
        raise ValueError
    return testdir


def is_replaytest(testname):
    """Tell whether ``testname`` is a replay test.

    A test counts as a replay test when either:

    * its ``test.yaml`` is a mapping with a truthy ``replay`` entry, or
    * its ``test.py`` contains the text ``def replay():``.

    Raises:
        ValueError: propagated from ``testdir`` when the test directory
            cannot be found.
    """
    d = testdir(testname)
    testyaml = os.path.join(d, "test.yaml")
    if os.path.isfile(testyaml):
        with open(testyaml, "r") as f:
            data = yaml.safe_load(f)
        # safe_load() yields None for an empty file and may yield a list or
        # scalar; only a mapping can carry the "replay" key.
        if isinstance(data, dict) and data.get("replay"):
            return True
    testpy = os.path.join(d, "test.py")
    if os.path.isfile(testpy):
        with open(testpy, "r") as f:
            content = f.read()
        if "def replay():" in content:
            return True
    return False


def failed_replay_tests():
    """Return the names of the failed replay tests of the last run.

    The report index is read from ``out/new``; an entry is kept when its
    status is ``TestStatus.FAIL`` and ``is_replaytest`` accepts its name.
    """
    report = ReportIndex.read(os.path.join("out", "new"))
    return [
        item.test_name
        for item in report.entries.values()
        if item.status == TestStatus.FAIL and is_replaytest(item.test_name)
    ]


def replay_test(args, testname):
    """Run the replay script from inside the directory of ``testname``.

    Args:
        args: parsed options; ``args.procs`` is forwarded as the ``-j`` value.
        testname: name of the test, resolved to a directory by ``testdir``.

    The working directory is always restored to ``curdir`` afterwards. A
    non-zero exit status of the replay script is reported but does not stop
    the processing of the remaining tests.

    Raises:
        ValueError: propagated from ``testdir`` when the test is not found.
    """
    d = testdir(testname)
    try:
        print("switching to", d)
        os.chdir(d)
        # The script path is relative to the test directory just entered.
        completed = subprocess.run(
            ["../../lib/python/replay.py", "-j", str(args.procs)]
        )
        if completed.returncode != 0:
            # Make the failure visible instead of silently dropping it.
            print(
                "replay of",
                testname,
                "exited with status",
                completed.returncode,
            )
    finally:
        os.chdir(curdir)


def run_all_tests_again(args):
    """Rerun the tests listed in ``args.testnames`` through ``./run-tests``.

    The names are written, one per line, to a temporary file that is passed
    to ``./run-tests`` via ``--testlist``; ``--rewrite`` is appended when
    ``args.rewrite`` is set. The temporary file is removed in every case,
    including when writing the list or launching the command fails.
    """
    print("running tests again")
    fd, tmpfile = tempfile.mkstemp()
    try:
        # Writing happens inside the try block so that a failure here does
        # not leave the temporary file behind.
        with os.fdopen(fd, "w") as f:
            f.writelines(name + "\n" for name in args.testnames)
        run_tests_cmd = [
            "./run-tests",
            "--disc",
            "large",
            "--diffs",
            "--testlist",
            tmpfile,
        ]
        if args.rewrite:
            run_tests_cmd.append("--rewrite")
        subprocess.run(run_tests_cmd)
    finally:
        os.remove(tmpfile)


def main():
    """Replay the sessions of the selected tests, then rerun those tests.

    The selection is the test names given on the command line plus, with
    ``--failed-replay-tests``, the failed replay tests of the last run.
    Exits with status 1 when the selection is empty. The final rerun is
    skipped with ``--skip-rerun-tests``.
    """
    args = parse_options()
    # run_all_tests_again() reads args.testnames, so the list is deliberately
    # updated in place rather than replaced by a new one.
    testnames = args.testnames
    if args.failed_replay_tests:
        testnames.extend(failed_replay_tests())
    # A test named explicitly may also be among the failed ones; replaying it
    # twice would only waste time. dict.fromkeys keeps the first occurrence
    # of each name, in order.
    testnames[:] = list(dict.fromkeys(testnames))
    if not testnames:
        print("no tests to replay")
        # SystemExit instead of exit(): the latter is injected by the site
        # module and is not guaranteed to exist.
        raise SystemExit(1)
    for testname in testnames:
        replay_test(args, testname)
    if not args.skip_rerun_tests:
        run_all_tests_again(args)


# Run the workflow only when executed as a script, so that importing this
# module does not parse sys.argv or start replaying tests.
if __name__ == "__main__":
    main()
