#!/usr/bin/env python3
from functools import cache
from pathlib import Path
from tempfile import mkstemp
from typing import Iterable
import argparse
import datetime
import os
import re
import shutil
import subprocess
import sys

from util import EXTENSIONS_TO_CHECK


# Matches the copyright line of the license header, including its trailing
# newline. The year must be four digits and start with "2".
COPYRIGHT_LINE_RE = re.compile(r"^// Copyright 2\d\d\d Signal Messenger, LLC\n$")
# The SPDX license identifier line of the header, including its trailing newline.
SPDX_LINE = "// SPDX-License-Identifier: AGPL-3.0-only\n"


def read_first_lines(line_count: int, path: Path) -> list[str]:
    """Return up to the first ``line_count`` lines of the file at ``path``.

    Lines keep their trailing newline characters. Fewer lines are returned
    when the file is shorter than ``line_count``.

    Args:
        line_count: Maximum number of lines to read. Zero or a negative
            value yields an empty list without opening the file.
        path: The UTF-8 text file to read.

    Returns:
        A list (always a list, never a string) of the lines read.
    """
    if line_count <= 0:
        return []
    with open(path, "rt", encoding="utf8") as file:
        # ``range`` is the first argument to ``zip`` so that iteration stops
        # as soon as enough lines were collected, without reading one more.
        return [line for _, line in zip(range(line_count), file)]


def has_shebang(line: str) -> bool:
    """Return True when ``line`` begins with the ``#!`` interpreter marker."""
    leading_two = line[:2]
    return leading_two == "#!"


def has_swift_tools_version(line: str) -> bool:
    """Return True when ``line`` begins with the Swift tools version marker."""
    marker = "// swift-tools-version:"
    return line[: len(marker)] == marker


def is_first_line_important(line: str) -> bool:
    """Return True when ``line`` is a shebang or a Swift tools version line.

    Callers in this file treat such a line as one that has to stay ahead of
    the license header.
    """
    if has_shebang(line):
        return True
    return has_swift_tools_version(line)


def has_valid_license_header(path: Path) -> bool:
    """Return whether the file at ``path`` starts with the expected header.

    The expected header is five lines: ``//``, the copyright line, the SPDX
    line, ``//`` and a blank line. It may be preceded by a single "important"
    first line (a shebang or a Swift tools version line).

    Files that are empty or too short to contain the header are reported as
    invalid instead of raising ``IndexError``.

    Args:
        path: The UTF-8 text file to inspect.

    Returns:
        True when the header is present and well-formed, False otherwise.
    """
    header_length = 5
    # One extra line is read to leave room for an important first line.
    first_six_lines = read_first_lines(header_length + 1, path)
    if not first_six_lines:
        return False

    if is_first_line_important(first_six_lines[0]):
        lines_to_consider = first_six_lines[1:]
    else:
        # Slice by length rather than dropping the last line, so that a file
        # with fewer than six lines does not lose a real header line.
        lines_to_consider = first_six_lines[:header_length]

    if len(lines_to_consider) < header_length:
        return False

    return (
        (lines_to_consider[0] == "//\n")
        and (COPYRIGHT_LINE_RE.match(lines_to_consider[1]) is not None)
        and (lines_to_consider[2] == SPDX_LINE)
        and (lines_to_consider[3] == "//\n")
        and (lines_to_consider[4] == "\n")
    )


@cache
def get_staging_file_path() -> Path:
    """Return the path of a temporary staging file, created on first use.

    The result is cached, so every call in a process returns the same path.

    Returns:
        The path of the staging file.
    """
    descriptor, path_str = mkstemp(prefix="signal-ios-license-staging-file-")
    # mkstemp hands back an open OS-level file descriptor; only the path is
    # needed here, so close the descriptor instead of leaking it.
    os.close(descriptor)
    return Path(path_str)


def get_creation_year_for(path: Path) -> str:
    """Return the four-digit year in which ``path`` first appeared in git.

    The year is taken from the last line of ``git log --follow`` output,
    i.e. the oldest commit listed for the file. When git lists no commits
    (for example a new file that has not been committed yet), the current
    year is used instead of raising ``IndexError``.

    Args:
        path: The file to look up.

    Returns:
        The year as a string such as ``"2023"``.

    Raises:
        subprocess.CalledProcessError: If ``git log`` exits with an error.
    """
    dates_changed = subprocess.check_output(
        ["git", "log", "--follow", "--format=%as", "--date=short", path.resolve()],
        text=True,
    ).splitlines()
    if not dates_changed:
        return str(datetime.date.today().year)
    # git log prints the newest commit first, so the last line is the oldest.
    return dates_changed[-1][0:4]


def lines_with_license_header_added(path: Path) -> Iterable[str]:
    """Yield the lines of ``path`` with a license header inserted.

    The header goes at the very top of the file, or directly after the first
    line when that line is "important" (a shebang or a Swift tools version
    line). Two edge cases are handled explicitly:

    * An important first line without a trailing newline gets one, so the
      header is not glued onto the same line.
    * An empty file still receives the header.

    Args:
        path: The UTF-8 text file to read.

    Yields:
        The lines of the new file contents, in order.
    """
    with open(path, "rt", encoding="utf8") as source_file:
        first_line = source_file.readline()

        if first_line and is_first_line_important(first_line):
            yield first_line if first_line.endswith("\n") else first_line + "\n"
            # Already emitted; must not be repeated after the header.
            first_line = ""

        yield "//\n"
        yield f"// Copyright {get_creation_year_for(path)} Signal Messenger, LLC\n"
        yield SPDX_LINE
        yield "//\n"
        yield "\n"

        if first_line:
            yield first_line
        yield from source_file


def add_license_header_to(path: Path) -> None:
    """Rewrite the file at ``path`` so that it carries a license header.

    The new contents are written to the staging file first and then moved
    over the original, so the original is only replaced once the new
    contents were written in full.

    Args:
        path: The UTF-8 text file to rewrite in place.
    """
    staging_file_path = get_staging_file_path()
    with open(staging_file_path, "wt", encoding="utf8") as staging_file:
        staging_file.writelines(lines_with_license_header_added(path))
    # Carry over the original permission bits (e.g. the executable bit of a
    # shebang script); otherwise the staging file's own mode would win.
    shutil.copymode(path, staging_file_path)
    # A plain rename can fail when the temporary directory is on a different
    # filesystem than the target; shutil.move falls back to copy-and-delete.
    shutil.move(str(staging_file_path), str(path))


def main() -> None:
    """Check every path given on the command line for a valid license header.

    With ``--fix``, files with an invalid header get one added. Without it,
    each offending path is reported on stderr and the process exits with
    status 1 if any were found.
    """
    arg_parser = argparse.ArgumentParser(
        description="Check license headers across the project"
    )
    arg_parser.add_argument("path", nargs="*", help="A path to process")
    arg_parser.add_argument(
        "--fix", action="store_true", help="Attempt to auto-add license headers"
    )
    args = arg_parser.parse_args()

    found_invalid = False
    for candidate in map(Path, args.path):
        if has_valid_license_header(candidate):
            continue
        if args.fix:
            add_license_header_to(candidate)
            continue
        print(f"{candidate} has an invalid license header", file=sys.stderr)
        found_invalid = True

    if found_invalid:
        sys.exit(1)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
