#!/usr/bin/env python3
#
# Copyright (C) 2018 Jonathan Lebon <jlebon@redhat.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this library; if not, write to the
# Free Software Foundation, Inc., 59 Temple Place - Suite 330,
# Boston, MA 02111-1307, USA.

"""
    This is a small helper CLI tool to create and modify updateinfo.xml data.
    Note that the original createrepo_c API is geared towards creating the XML
    in a single pass (e.g. Bodhi), so there are inefficiencies in the way we
    modify data below (e.g. removing data involves copying). Our goal here is
    to make it *really* easy to use from the command-line for testing purposes.
"""

import os
import sys
import argparse

from collections import namedtuple

import rpm
import createrepo_c as cr


def main():
    '''
        Parse the command line and run the handler that the selected
        subcommand stored in `func`.
    '''
    parsed = parse_args()
    handler = parsed.func
    handler(parsed)


def parse_args():
    '''
        Build the argument parser and parse the process command line.

        Every subcommand registers its handler through set_defaults(func=...),
        so the returned namespace can be dispatched with `args.func(args)`.
        The global --repo option defaults to the current working directory.
    '''

    parser = argparse.ArgumentParser()
    parser.add_argument('--repo', help='rpmmd repo path', default=os.getcwd())
    subparsers = parser.add_subparsers(dest='cmd', title='subcommands')
    # a subcommand is mandatory
    subparsers.required = True

    add = subparsers.add_parser('add')
    add.set_defaults(func=cmd_add)
    add.add_argument('--if-not-exists', action='store_true')
    add.add_argument('id')
    add.add_argument('type', default='security', nargs='?',
                     choices=['bugfix', 'enhancement', 'security'])
    add.add_argument('severity', default='none', nargs='?',
                     choices=['none', 'low', 'moderate',
                              'important', 'critical'])

    delete = subparsers.add_parser('delete')
    delete.set_defaults(func=cmd_delete)
    # this flag belongs to 'delete' (cmd_delete reads args.if_exists); it
    # must not be registered on the 'add' subparser
    delete.add_argument('--if-exists', action='store_true')
    delete.add_argument('id')

    show = subparsers.add_parser('show')
    show.set_defaults(func=cmd_show)
    show.add_argument('id', nargs='?')

    add_pkg = subparsers.add_parser('add-pkg')
    add_pkg.set_defaults(func=cmd_add_pkg)
    add_pkg.add_argument('id')
    add_pkg.add_argument('name')
    # the remaining fields are optional: cmd_add_pkg treats a missing epoch
    # as "only the name was given"
    add_pkg.add_argument('epoch', nargs='?', type=int)
    add_pkg.add_argument('version', nargs='?')
    add_pkg.add_argument('release', nargs='?')
    add_pkg.add_argument('arch', nargs='?')

    delete_pkg = subparsers.add_parser('delete-pkg')
    delete_pkg.set_defaults(func=cmd_delete_pkg)
    delete_pkg.add_argument('id')
    delete_pkg.add_argument('name_or_nevra')

    add_ref = subparsers.add_parser('add-ref')
    add_ref.set_defaults(func=cmd_add_ref)
    add_ref.add_argument('id')
    add_ref.add_argument('refid')
    add_ref.add_argument('url')
    add_ref.add_argument('title')

    delete_ref = subparsers.add_parser('delete-ref')
    delete_ref.set_defaults(func=cmd_delete_ref)
    delete_ref.add_argument('id')
    delete_ref.add_argument('refid')

    return parser.parse_args()


def cmd_add(args):
    '''
        Handler for 'add': create update `args.id` with the requested type
        and severity, then write the updateinfo back to the repo.

        If add_update() raises UpdateExistsError and --if-not-exists was
        given, return without writing anything; otherwise re-raise.
    '''
    current = get_updateinfo(args.repo)

    try:
        updated = add_update(current, args.id, args.type, args.severity)
    except UpdateExistsError:
        if not args.if_not_exists:
            raise
        return

    set_updateinfo(args.repo, updated)


def cmd_show(args):
    '''
        Handler for 'show': print the repo's updates, restricted to
        `args.id` when one was given.
    '''
    show_updates(get_updateinfo(args.repo), args.id)


def cmd_delete(args):
    '''
        Handler for 'delete': remove update `args.id` from the repo's
        updateinfo and write the result back.

        If delete_update() raises UpdateNotExistsError and --if-exists was
        given, return without writing anything; otherwise re-raise.
    '''

    uinfo = get_updateinfo(args.repo)

    try:
        uinfo = delete_update(uinfo, args.id)
    except UpdateNotExistsError:
        # use getattr() so that a namespace lacking the option behaves as if
        # --if-exists was not passed, rather than dying with AttributeError
        if getattr(args, 'if_exists', False):
            return
        raise

    set_updateinfo(args.repo, uinfo)


class Nevra(namedtuple('Nevra', 'name epoch version release arch')):
    '''
        Package identity tuple: name, epoch, version, release, arch.

        The epoch may be an int, a numeric string, or None; None and the
        empty string are treated like epoch 0.
    '''

    @staticmethod
    def _epoch_as_int(epoch):
        # a missing epoch means the same thing as epoch 0
        if epoch is None or epoch == '':
            return 0
        return int(epoch)

    def __str__(self):
        # the epoch is only spelled out when it is non-zero
        epoch = self._epoch_as_int(self.epoch)
        if epoch > 0:
            return '%s-%u:%s-%s.%s' % (self.name, epoch, self.version,
                                       self.release, self.arch)
        return '%s-%s-%s.%s' % (self.name, self.version,
                                self.release, self.arch)

    # generic equals check for cr.UpdateCollectionPackage/cr.Package objects
    def equals_pkg(self, pkg):
        '''
            Return True if `pkg` has the same name, epoch, version, release
            and arch as this Nevra.
        '''
        # epochs are compared numerically: this file stores them as strings
        # on package objects but as ints on Nevra, and a raw comparison of
        # 0 with '0' would never match
        if self._epoch_as_int(self.epoch) != self._epoch_as_int(pkg.epoch):
            return False
        return all(getattr(self, attr) == getattr(pkg, attr)
                   for attr in ('name', 'version', 'release', 'arch'))


def nevra_from_rpm(rpm_filename):
    '''
        Read the header of the RPM file at `rpm_filename` and return its
        name, epoch, version, release and arch as a Nevra.
    '''

    ts = rpm.TransactionSet()
    # RPM files are binary data, so don't open them in text mode
    with open(rpm_filename, 'rb') as f:
        hdr = ts.hdrFromFdno(f)

    return Nevra(hdr[rpm.RPMTAG_NAME],
                 hdr[rpm.RPMTAG_EPOCH],
                 hdr[rpm.RPMTAG_VERSION],
                 hdr[rpm.RPMTAG_RELEASE],
                 hdr[rpm.RPMTAG_ARCH])


def nevra_from_obj(obj):
    '''
        Build a Nevra from any object exposing name, epoch, version, release
        and arch attributes. A None epoch becomes 0; anything else goes
        through int().
    '''
    epoch = int(obj.epoch) if obj.epoch is not None else 0
    return Nevra(name=obj.name, epoch=epoch, version=obj.version,
                 release=obj.release, arch=obj.arch)


def nevra_from_primary(repo, name):
    '''
        Look up the package called `name` in the repo's primary metadata and
        return its Nevra. Raises if there is no such package or if more than
        one package carries that name.
    '''
    matches = []

    def collect_if_named(pkg):
        # invoked by the parser for each package it reads
        if pkg.name == name:
            matches.append(pkg)

    repomd = cr.Repomd(repomd_xml(repo))

    # only the first 'primary' record is parsed
    primary = next((rec for rec in repomd.records if rec.type == 'primary'),
                   None)
    if primary is not None:
        primary_path = os.path.join(repo, primary.location_href)
        cr.xml_parse_primary(primary_path, do_files=False,
                             pkgcb=collect_if_named)

    if not matches:
        raise Exception("Package '%s' not found" % name)
    if len(matches) > 1:
        raise Exception("Multiple packages found for '%s'" % name)
    return nevra_from_obj(matches[0])


def cmd_add_pkg(args):
    '''
        Handler for 'add-pkg': attach a package to update `args.id`.

        When an epoch was given, the package is described by the command-line
        fields. Otherwise `args.name` is either the path to an RPM file to
        read the fields from, or a package name to look up in the repo's
        primary metadata.
    '''
    current = get_updateinfo(args.repo)

    if args.epoch is not None:
        nevra = nevra_from_obj(args)
    elif os.path.isfile(args.name):
        # we support passing an RPM file from which to extract fields
        nevra = nevra_from_rpm(args.name)
    else:
        # try to find pkg in primary
        nevra = nevra_from_primary(args.repo, args.name)

    updated = add_pkg_to_update(current, args.id, nevra)
    set_updateinfo(args.repo, updated)


def cmd_delete_pkg(args):
    '''
        Handler for 'delete-pkg': drop the package identified by
        `args.name_or_nevra` from update `args.id` and write the result back.
    '''
    current = get_updateinfo(args.repo)
    updated = delete_pkg_from_update(current, args.id, args.name_or_nevra)
    set_updateinfo(args.repo, updated)


def cmd_add_ref(args):
    '''
        Handler for 'add-ref': attach a reference (id, url, title) to update
        `args.id` and write the result back.
    '''
    ref_tuple = (args.refid, args.url, args.title)
    current = get_updateinfo(args.repo)
    updated = add_ref_to_update(current, args.id, ref_tuple)
    set_updateinfo(args.repo, updated)


def cmd_delete_ref(args):
    '''
        Handler for 'delete-ref': remove reference `args.refid` from update
        `args.id` and write the result back.
    '''
    current = get_updateinfo(args.repo)
    updated = delete_ref_from_update(current, args.id, args.refid)
    set_updateinfo(args.repo, updated)


def get_updateinfo(repo):
    '''
        Load the updateinfo referenced by the repo's repomd.xml, or return an
        empty cr.UpdateInfo when the repo has no 'updateinfo' record.
    '''
    records = cr.Repomd(repomd_xml(repo)).records
    existing = next((rec for rec in records if rec.type == 'updateinfo'),
                    None)
    if existing is None:
        return cr.UpdateInfo()
    return cr.UpdateInfo(os.path.join(repo, existing.location_href))


def repomd_xml(repo):
    '''
        Return the path of repodata/repomd.xml inside `repo`.
    '''
    relative_path = "repodata/repomd.xml"
    return os.path.join(repo, relative_path)


def sev2xml(sev):
    '''
        Upper-case the first character of `sev` and leave the rest untouched.
        An empty string is returned unchanged.
    '''
    # important -> Important, which is what yum/dnf expects
    # (slicing rather than indexing so that '' doesn't raise IndexError)
    return sev[:1].upper() + sev[1:]


class UpdateExistsError(Exception):
    '''
        Raised by add_update() when the requested update id is already
        present.
    '''


class UpdateNotExistsError(Exception):
    '''
        Signals that an update id is not present; cmd_delete() catches it to
        implement --if-exists.
    '''


def add_update(uinfo, uid, utype, severity):
    '''
        Append a new update record to `uinfo` and return `uinfo`.

        The record gets id `uid`, type `utype` and the severity converted
        with sev2xml(). Raises UpdateExistsError if `uinfo` already holds an
        update with that id.
    '''

    # check that the target id doesn't already exist (compare against the
    # `uid` parameter, not the `id` builtin)
    if any(update.id == uid for update in uinfo.updates):
        raise UpdateExistsError("Update '%s' already exists" % uid)

    rec = cr.UpdateRecord()
    rec.id = uid
    rec.type = utype
    rec.severity = sev2xml(severity)
    uinfo.append(rec)
    return uinfo


def modify_update(uinfo, uid, func, func_data=None):
    '''
        Return a new cr.UpdateInfo built from `uinfo` with update `uid`
        transformed.

        Updates with a different id are carried over as-is. For the matching
        update, `func(uid, update, func_data)` is called and its return value
        is appended in place of the original; if `func` is None, or returns
        None, the update is left out of the result.

        Raises UpdateNotExistsError (an Exception subclass) if no update has
        id `uid`.
    '''

    # createrepo_c doesn't allow modifying the original object
    new_uinfo = cr.UpdateInfo()

    found = False
    for update in uinfo.updates:
        if update.id != uid:
            new_uinfo.append(update)
            continue
        found = True
        if func is None:
            # no callback: the update is simply dropped
            continue
        new_update = func(uid, update, func_data)
        if new_update is not None:
            new_uinfo.append(new_update)
    if not found:
        # dedicated type so that callers such as cmd_delete can tell a
        # missing update apart from other failures
        raise UpdateNotExistsError("Update '%s' does not exist" % uid)
    return new_uinfo


def delete_update(uinfo, uid):
    '''
        Return a copy of `uinfo` without update `uid`.
    '''
    # passing no callback makes modify_update() drop the matching update
    without_update = modify_update(uinfo, uid, None)
    return without_update


def show_updates(uinfo, uid):
    '''
        Print each update in `uinfo` (only the one with id `uid` when `uid`
        is not None): one line with id, type and severity, followed by an
        indented line per package filename and per reference.
    '''
    for record in uinfo.updates:
        if uid is None or record.id == uid:
            print(record.id, record.type, record.severity)
            for collection in record.collections:
                for package in collection.packages:
                    print("  pkg:", package.filename)
            for reference in record.references:
                print("  ref:", reference.id, reference.href)


def shallow_clone_update(update):
    '''
        Return a fresh cr.UpdateRecord carrying only the id, type and
        severity of `update`.
    '''
    # the default copy() also copies collections and refs, and then there's no
    # API to *remove* from those lists, so instead we re-create from scratch
    # each time we need to modify them
    clone = cr.UpdateRecord()
    for attr in ('id', 'type', 'severity'):
        setattr(clone, attr, getattr(update, attr))
    return clone


def add_pkg_to_update_cb(uid, update, nevra):
    '''
        modify_update() callback: return a rebuilt copy of `update` whose
        single collection additionally contains the package described by
        `nevra`. Raises if the update has more than one collection or if
        nevra.equals_pkg() matches a package that is already there.
    '''
    collections = update.collections
    if len(collections) > 1:
        # let's just pretend that never happens for our purposes
        raise Exception("Update '%s' has more than one collection" % uid)
    if len(collections) == 1:
        old_col = collections[0]
    else:
        old_col = cr.UpdateCollection()

    rebuilt = shallow_clone_update(update)
    rebuilt_col = cr.UpdateCollection()

    # carry over the existing packages, refusing duplicates
    for existing in old_col.packages:
        if nevra.equals_pkg(existing):
            raise Exception("Update '%s' already contains pkg '%s'" %
                            (uid, nevra))
        rebuilt_col.append(existing)

    added = cr.UpdateCollectionPackage()
    added.name = nevra.name
    # the package object gets the epoch as a string, '0' when unknown
    added.epoch = '0' if nevra.epoch is None else str(nevra.epoch)
    added.version = nevra.version
    added.release = nevra.release
    added.arch = nevra.arch
    added.filename = str(nevra) + ".rpm"
    rebuilt_col.append(added)

    rebuilt.append_collection(rebuilt_col)

    # and copy over the references
    for reference in update.references:
        rebuilt.append_reference(reference)

    return rebuilt


def add_pkg_to_update(uinfo, uid, nevra):
    '''
        Return a copy of `uinfo` in which update `uid` also lists the
        package described by `nevra`.
    '''
    return modify_update(uinfo, uid, add_pkg_to_update_cb, func_data=nevra)


def delete_pkg_from_update_cb(uid, update, name_or_nevra):
    '''
        modify_update() callback: return a rebuilt copy of `update` without
        the packages whose name equals `name_or_nevra` or whose filename
        equals `name_or_nevra` + '.rpm'. Raises if the update has more than
        one collection or if nothing matched.
    '''
    collections = update.collections
    if len(collections) > 1:
        # let's just pretend that never happens for our purposes
        raise Exception("Update '%s' has more than one collection" % uid)
    if len(collections) == 1:
        old_col = collections[0]
    else:
        old_col = cr.UpdateCollection()

    rebuilt = shallow_clone_update(update)
    kept_col = cr.UpdateCollection()

    # just compare by filename to make it easier
    wanted_filename = name_or_nevra + '.rpm'
    removed_any = False
    for existing in old_col.packages:
        if (existing.filename == wanted_filename
                or existing.name == name_or_nevra):
            removed_any = True
            continue
        kept_col.append(existing)

    if not removed_any:
        raise Exception("Update '%s' does not have package '%s'" %
                        (uid, name_or_nevra))

    # a collection left without packages is not re-attached
    if len(kept_col.packages) > 0:
        rebuilt.append_collection(kept_col)

    # and copy over the references
    for reference in update.references:
        rebuilt.append_reference(reference)

    return rebuilt


def delete_pkg_from_update(uinfo, uid, name_or_nevra):
    '''
        Return a copy of `uinfo` in which update `uid` no longer lists the
        package identified by `name_or_nevra`.
    '''
    return modify_update(uinfo, uid, delete_pkg_from_update_cb,
                         func_data=name_or_nevra)


def add_ref_to_update_cb(uid, update, ref_tuple):
    '''
        modify_update() callback: return a rebuilt copy of `update` with a
        new "bugzilla" reference built from `ref_tuple` (id, href, title)
        placed ahead of the existing references. Raises if a reference with
        the same id is already present.
    '''
    rebuilt = shallow_clone_update(update)

    added = cr.UpdateReference()
    added.type = "bugzilla"
    added.id, added.href, added.title = ref_tuple
    rebuilt.append_reference(added)

    # carry over the existing references, refusing a duplicate id
    for existing in update.references:
        if existing.id == added.id:
            raise Exception("Update '%s' already contains ref '%s'" %
                            (uid, existing.id))
        rebuilt.append_reference(existing)

    # and copy over the collections
    for collection in update.collections:
        rebuilt.append_collection(collection)

    return rebuilt


def add_ref_to_update(uinfo, uid, ref_tuple):
    '''
        Return a copy of `uinfo` in which update `uid` carries the reference
        described by `ref_tuple` (id, href, title).
    '''
    return modify_update(uinfo, uid, add_ref_to_update_cb,
                         func_data=ref_tuple)


def delete_ref_from_update_cb(uid, update, refid):
    '''
        modify_update() callback: return a rebuilt copy of `update` without
        the references whose id is `refid`. Raises if no reference matched.
    '''
    rebuilt = shallow_clone_update(update)

    removed_any = False
    for existing in update.references:
        if existing.id != refid:
            rebuilt.append_reference(existing)
        else:
            removed_any = True

    if not removed_any:
        raise Exception("Update '%s' does not have ref '%s'" %
                        (uid, refid))

    # and copy over the collections
    for collection in update.collections:
        rebuilt.append_collection(collection)

    return rebuilt


def delete_ref_from_update(uinfo, uid, refid):
    '''
        Return a copy of `uinfo` in which update `uid` no longer carries
        reference `refid`.
    '''
    return modify_update(uinfo, uid, delete_ref_from_update_cb,
                         func_data=refid)


def new_updateinfo_record(repo, uinfo):
    '''
        Dump `uinfo` as gzip-compressed XML to repodata/updateinfo.xml.gz
        under `repo` and return a cr.RepomdRecord of type 'updateinfo'
        describing that file.
    '''
    # imported locally because nothing else in this file needs it
    import gzip

    xml = os.path.join(repo, "repodata/updateinfo.xml.gz")
    # the file name advertises gzip, so the payload has to really be
    # compressed instead of being written out as plain text
    with gzip.open(xml, 'wt', encoding='utf-8') as f:
        f.write(uinfo.xml_dump())

    # calculate SHA256 and rename
    ui_rec = cr.RepomdRecord('updateinfo', xml)
    ui_rec.fill(cr.SHA256)
    ui_rec.rename_file()
    return ui_rec


def set_updateinfo(repo, uinfo):
    '''
        Make `uinfo` the repo's updateinfo: delete the file behind any
        existing 'updateinfo' record, write a new one and rewrite repomd.xml
        with the other records carried over.
    '''
    old_repomd = cr.Repomd(repomd_xml(repo))

    # clone
    fresh_repomd = cr.Repomd()
    for rec in old_repomd.records:
        if rec.type == 'updateinfo':
            # the previous updateinfo file is removed before the new one is
            # written
            os.unlink(os.path.join(repo, rec.location_href))
            continue
        fresh_repomd.set_record(rec)

    fresh_repomd.set_record(new_updateinfo_record(repo, uinfo))

    repomd_path = repomd_xml(repo)
    with open(repomd_path, 'w') as out:
        out.write(fresh_repomd.xml_dump())


# script entry point: main()'s return value becomes the exit status
if __name__ == '__main__':
    exit_status = main()
    sys.exit(exit_status)
