#!/usr/bin/python3
# generate debian/copyright from tools/debian/copyright.template and the *.LEGAL.txt files of the dist/ bundles
# Author: Martin Pitt <mpitt@debian.org>
#         Allison Karlitskaya <allison.karlitskaya@redhat.com>

import argparse
import gzip
import os
import re
import sys
import time

# Parent of the directory that contains this script
BASE_DIR = os.path.realpath(os.path.join(__file__, '..', '..'))
# Template into which the generated "Files:" paragraphs get inserted
TEMPLATE_FILE = BASE_DIR + '/tools/debian/copyright.template'


# Copyright line for our own code; the end year is the year of this run
own_copyright = "Copyright (C) 2013 - {} Red Hat, Inc.".format(time.strftime('%Y'))

# regex → list of license short names implied by a match
license_patterns = {
    # Common patterns
    r'\bMIT\b': [
        'MIT',
    ],

    # https://github.com/focus-trap/focus-trap/blob/master/LICENSE
    r'\bfocus-trap\b': [
        'MIT',
    ],
}

# regex → list of match-expansion templates (or fixed strings) that yield
# the copyright holder lines for a match
copyright_patterns = {
    # Common patterns
    r'Copyright (.*)$': [
        r'\1',
    ],
    r'@copyright (.*)$': [
        r'\1',
    ],
    r'\(c\) (.*)$': [
        r'\1',
    ],

    # https://github.com/focus-trap/focus-trap/blob/master/LICENSE
    r'\bfocus-trap\b': [
        '2015-2016 David Clark',
    ],
}

# Patterns (from either table) that matched at least once during the scan
used_patterns = set()


def parse_args():
    """Parse the command line and return the resulting argparse namespace.

    No options of our own are declared; only what argparse provides by
    default is accepted.
    """
    parser = argparse.ArgumentParser(
        description='Generate debian/copyright file from template and node_modules',
    )
    namespace = parser.parse_args()
    return namespace


def template_licenses(template):
    """Return set of existing License: short names

    Every line of `template` that starts with "License:" contributes the
    rest of that line, stripped of surrounding whitespace and lower-cased.
    A "License:" line without any name is ignored instead of raising
    IndexError, and a missing space after the colon is tolerated.
    """
    prefix = 'License:'

    ids = set()
    for line in template.splitlines():
        if not line.startswith(prefix):
            continue
        # strip() also drops trailing blanks, which would otherwise make the
        # later membership test against detected license ids fail
        name = line[len(prefix):].strip()
        if name:
            ids.add(name.lower())
    return ids


def find_patterns(patterns, text):
    """Collect the expansions of all patterns that match somewhere in text.

    `patterns` maps a regex to a list of expansion templates. Each regex is
    searched in `text` in multi-line mode; for every match, each of its
    templates is expanded against that match and added to the result set.
    Regexes that matched at least once are recorded in the module-level
    `used_patterns` set.
    """
    found = set()

    for regex, expansions in patterns.items():
        matches = list(re.finditer(regex, text, re.MULTILINE))
        if not matches:
            continue

        used_patterns.add(regex)
        for hit in matches:
            for expansion in expansions:
                found.add(hit.expand(expansion))

    return found

#
# main
#


# No options are defined, but parsing still lets argparse handle the command line
args = parse_args()

with open(TEMPLATE_FILE, encoding='UTF-8') as f:
    template = f.read()

# Lower-cased license short names that the template already describes
license_ids = template_licenses(template)

# scan dist/ bundles for third-party copyrights and licenses

dist_copyrights = {}  # Files: dirglob → set(copyrights)
dist_licenses = {}  # Files: dirglob → set(licenses)

for directory, _subdirs, files in os.walk(f'{BASE_DIR}/dist'):
    for file in files:
        if '.LEGAL.txt' not in file:
            continue

        full_filename = os.path.join(directory, file)
        # All files of one dist/ directory share a single Files: paragraph
        directory_glob = os.path.relpath(directory, start=BASE_DIR) + '/*'

        # Decode as UTF-8 explicitly, like the template: the locale default
        # would fail on non-ASCII names when running in the C locale
        opener = gzip.open if file.endswith('.gz') else open
        with opener(full_filename, 'rt', encoding='UTF-8') as license_file:
            contents = license_file.read()

        # Blank-line separated chunks are examined independently
        for comment in contents.split('\n\n'):
            if comment.strip() == "" or "Bundled license information:" in comment:
                continue

            licenses = find_patterns(license_patterns, comment)
            if not licenses:
                raise SystemError(f'Can not determine licenses of:\n{comment}')
            for license_id in licenses:
                if license_id.lower() not in license_ids:
                    # every detected license needs a License: entry in the template
                    raise KeyError(f'License {license_id} not found in {TEMPLATE_FILE}')

            # All bundles also contain our own code
            licenses.add("LGPL-2.1-or-later")

            dist_licenses.setdefault(directory_glob, set()).update(licenses)

            copyrights = find_patterns(copyright_patterns, comment)
            if not copyrights:
                raise SystemError(f'Did not find any copyrights in:\n{comment}')

            # All bundles also contain our own code
            copyrights.add(own_copyright)

            dist_copyrights.setdefault(directory_glob, set()).update(copyrights)

# Every configured pattern is expected to have matched at least once.
# Development builds have no LEGAL.txt files, so nothing can match there;
# the check is skipped for them and when IGNORE_UNUSED_PATTERNS is set.
skip_unused_check = os.getenv('NODE_ENV') == 'development' or os.getenv('IGNORE_UNUSED_PATTERNS')

for pattern in set(license_patterns) | set(copyright_patterns):
    if pattern in used_patterns or skip_unused_check:
        continue

    sys.exit(f'build-debian-copyright: Unused pattern: {pattern}')

# One "Files:" paragraph per scanned directory glob, in sorted order
paragraphs = [
    "Files: {0}\nCopyright: {1}\nLicense: {2}".format(
        glob,
        '\n '.join(sorted(dist_copyrights[glob])),
        ' and '.join(sorted(dist_licenses[glob])),
    )
    for glob in sorted(dist_copyrights)
]
npm_section = '\n\n'.join(paragraphs)

# Write bytes to the underlying buffer so that the output is UTF-8
# even when running in C locale
stdout_bytes = sys.stdout.buffer
for line in template.splitlines():
    # A template line containing the #NPM marker is replaced by the paragraphs
    text = npm_section if '#NPM' in line else line
    stdout_bytes.write(text.encode('UTF-8') + b'\n')
