#!/usr/bin/python3
#
# Copyright (C) 2019 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import argparse
from collections import OrderedDict
import logging as log
import os
import shutil
import shlex
import subprocess
import sys
import tempfile


# Maximum filename length; CheckBigFiles raises a RuntimeError when a name
# being added to the iso is longer than this.
MAX_FNAME = 253


def SplitCmdline(cmdline):
    """
    Split the commandline into an OrderedDict of arguments

    If the argument appears multiple times its entry will have a list of the values,
    and its position in the dictionary will be that of the first occurrence.
    If the argument has no value its entry will be [None]
    Arguments with a single value will be [VALUE]

    Quoted strings are kept together (the quotes themselves are removed), and
    values may contain equals signs which will be preserved.

    Backslashes are kept verbatim. Config files reference the volume id in its
    udev escaped form (eg. LABEL=My\\x20Label) and treating the backslash as a
    shell escape character would silently corrupt the label.
    """
    # Same setup as shlex.split(), but with backslash escapes turned off
    lexer = shlex.shlex(cmdline, posix=True)
    lexer.whitespace_split = True
    lexer.commenters = ""
    lexer.escape = ""

    kernel_args = OrderedDict()
    for arg in lexer:
        # Only split on the first '=', the value may contain more of them
        key, sep, value = arg.partition("=")
        kernel_args.setdefault(key, []).append(value if sep else None)

    return kernel_args


def quote(s):
    """
    Wrap the string in double quotes when it contains a space.

    Strings without a space are returned unchanged.
    """
    return f"\"{s}\"" if " " in s else s


def ListKernelArgs(kernel_args):
    """
    Flatten the kernel argument dictionary into a list of strings

    The dictionary is the output from SplitCmdline. Every value of an argument
    produces its own entry: NAME when the value is None, otherwise NAME=VALUE
    with the value passed through quote() so that values with spaces are
    wrapped in "".
    """
    return [
        f"{name}" if value is None else f"{name}={quote(value)}"
        for name, values in kernel_args.items()
        for value in values
    ]


def JoinKernelArgs(kernel_args):
    """
    Build a single space separated cmdline string from the kernel arguments

    The dictionary is the output from SplitCmdline.
    The individual entries come from ListKernelArgs, so values with spaces
    are quoted with ""
    """
    pieces = ListKernelArgs(kernel_args)
    return " ".join(pieces)


def AlterKernelArgs(kernel_args, rm_args, add_args):
    """
    Modify the kernel argument OrderedDict

    This first removes all arguments listed in rm_args. It will ignore
    arguments that are not present.

    It then extends the arguments listed in the add_args dictionary if they are
    already present, otherwise it adds them to kernel_args

    kernel_args is modified in place and also returned. add_args is never
    modified: new entries get their own copy of the value list, because the
    same add_args dictionary is reused for every cmdline that is edited.
    """
    # Remove arguments, silently skipping the ones that are not present
    for k in rm_args:
        kernel_args.pop(k, None)

    # Add new arguments and update existing ones
    for k, vs in add_args.items():
        if k in kernel_args:
            kernel_args[k].extend(vs)
        else:
            # Copy so later changes to kernel_args cannot leak into add_args
            kernel_args[k] = list(vs)

    return kernel_args


def WrapKernelArgs(kernel_args, width=80):
    """
    Return a wordwrapped string

    kernel_args is a list of argument strings (eg. from ListKernelArgs).
    Lines are only broken at the spaces in between arguments, so quoted
    strings are preserved verbatim. An argument longer than width is placed
    on a line of its own.

    The result always ends with a newline; an empty list returns just "\\n".
    """
    lines = []
    for arg in kernel_args:
        # Stay on the current line while the argument and its separating space fit
        if lines and len(lines[-1]) + 1 + len(arg) <= width:
            lines[-1] += " " + arg
        else:
            lines.append(arg)
    return "\n".join(lines) + "\n"


def GetISODetails(isopath):
    """
    Run xorriso on the iso to find its volume id and list its contents

    xorriso writes the metadata about the iso to stderr and the file listing
    to stdout.

    Returns a tuple of volume id, and the list of files on the iso
    Raises a RuntimeError if no volume id was reported.
    """
    proc = subprocess.run(["xorriso", "-indev", isopath, "-pkt_output", "on", "-find"],
                          check=True, capture_output=True, env={"LANG": "C"})

    # The last 'Volume id' line wins, its final word is the (quoted) id
    volid = ""
    for info in proc.stderr.decode("utf-8").splitlines():
        if info.startswith("Volume id"):
            volid = shlex.split(info)[-1]
    if not volid:
        raise RuntimeError(f"{isopath} is missing a volume id")

    # Result lines on stdout are prefixed with 'R:1:', the path is the last word
    listing = proc.stdout.decode("utf-8").splitlines()
    files = [os.path.normpath(shlex.split(entry)[-1])
             for entry in listing if entry.startswith("R:1:")]

    return volid, files


def ExtractISOFiles(isopath, files, tmpdir):
    """
    Use osirrox to extract the given files (which must exist on the iso)
    into the temporary directory, keeping their relative paths.
    """
    extract_args = []
    for name in files:
        extract_args += ["-extract", name, tmpdir + "/" + name]

    # -chmod_r makes sure the user can write to any extracted directories
    cmd = ["osirrox", "-indev", isopath, "-chmod_r", "u+rwx", "/", "--",
           *extract_args,
           "-rollback_end"]
    subprocess.run(cmd, check=True, capture_output=False, env={"LANG": "C"})


# From pylorax/treebuilder.py
# udev whitelist: 'a-zA-Z0-9#+.:=@_-' (see is_whitelisted in libudev-util.c)
# The blacklist is the printable ASCII characters that are not in the
# whitelist, followed by the 32 non-printable ASCII control characters.
udev_blacklist = ' !"$%&\'()*,/;<>?[\\]^`{|}~' + ''.join(map(chr, range(32)))


def udev_escape(label):
    """
    Escape the volume id label characters so they can be used on the ISO

    Every character found in udev_blacklist is replaced by \\xNN, where NN is
    its two digit hexadecimal character code. Other characters are unchanged.
    """
    return ''.join('\\x%02x' % ord(ch) if ch in udev_blacklist else ch
                   for ch in label)


def CheckBigFiles(add_paths):
    """
    Check file size and filename length for problems

    add_paths is a list of files and directories. Directories are walked and
    everything below them is examined, so this may take some time.

    Returns True if any file exceeds 4GiB
    Raises a RuntimeError if any filename is longer than MAX_FNAME
    """
    size_limit = 4*1024**3
    big_file = False
    for src in add_paths:
        if os.path.isdir(src):
            for top, dirs, files in os.walk(src):
                for f in files + dirs:
                    if len(f) > MAX_FNAME:
                        raise RuntimeError("iso contains filenames that are too long: %s" % f)
                    if os.stat(top + "/" + f).st_size >= size_limit:
                        big_file = True
        else:
            # Only the final component is a filename; the length of the
            # directories leading to it on the host does not matter.
            name = os.path.basename(src)
            if len(name) > MAX_FNAME:
                raise RuntimeError("iso contains filenames that are too long: %s" % name)
            if os.stat(src).st_size >= size_limit:
                big_file = True

    return big_file


def ImplantMD5(output_iso):
    """
    Add md5 checksums to the final iso

    Runs implantisomd5 on output_iso.

    Raises a RuntimeError if implantisomd5 cannot be started (eg. it is not
    installed) or if it exits with an error.
    """
    cmd = ["implantisomd5", output_iso]
    log.debug(" ".join(cmd))
    try:
        subprocess.check_output(cmd)
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError covers a missing or non-executable binary
        log.error(str(e))
        raise RuntimeError("implantisomd5 failed") from e


def RebuildS390CDBoot(tmpdir):
    """
    On s390x the cdboot.img needs to be rebuilt with the new cmdline arguments

    Runs mk-s390image on the kernel.img, initrd.img and cdboot.prm files under
    tmpdir/images/ and writes the result to tmpdir/images/cdboot.img

    Raises a RuntimeError if any of the input files are missing, if
    mk-s390image cannot be started (eg. it is not installed) or if it exits
    with an error.
    """
    # First check for the needed files
    missing = []
    for f in ["images/kernel.img", "images/initrd.img", "images/cdboot.prm"]:
        if not os.path.exists(tmpdir + "/" + f):
            log.debug("Missing file %s", f)
            missing.append(f)
    if missing:
        raise RuntimeError("Missing requirement %s" % ", ".join(missing))

    cmd = ["mk-s390image", tmpdir + "/images/kernel.img", tmpdir + "/images/cdboot.img",
           "-r", tmpdir + "/images/initrd.img",
           "-p", tmpdir + "/images/cdboot.prm"]
    log.debug(" ".join(cmd))
    try:
        subprocess.check_output(cmd)
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError covers a missing or non-executable binary
        log.error(str(e))
        raise RuntimeError("Running mk-s390image") from e


def GetCmdline(line, commands):
    """
    Get the cmdline portion of a line that starts with a command

    The command must be the first word of the line (ignoring leading
    whitespace) and be followed by a space or the end of the line.

    Returns a tuple of the indented command prefix and the cmdline with the
    surrounding whitespace removed. A command with nothing after it returns
    an empty cmdline.
    If the line does not start with one of the commands it returns
    ("", line) with the line untouched.
    """
    # Indexing the split result (instead of unpacking it) also handles
    # lines without any space in them, like a bare command or '#linux'
    first = line.strip().split(" ", 1)[0]
    if not first or first not in commands:
        return "", line

    indent, prefix, cmdline = line.partition(first)
    return indent+prefix, cmdline.strip()


def EditIsolinux(rm_args, add_args, new_volid, old_volid, tmpdir):
    """
    Rewrite the kernel cmdline in isolinux/isolinux.cfg under tmpdir

    Every occurrence of old_volid is replaced with new_volid when they differ.
    On 'append' lines the arguments in rm_args are removed and the arguments
    in add_args are added. All other lines are copied as they are.

    Logs a warning and does nothing if the config file does not exist.
    """
    cfg_path = tmpdir + "/isolinux/isolinux.cfg"
    if not os.path.exists(cfg_path):
        log.warning("No isolinux/isolinux.cfg file found")
        return

    replace_volid = new_volid != old_volid
    new_path = cfg_path + ".new"

    # Write the edited config to .new, then move it over the original
    with open(cfg_path, "r") as src, open(new_path, "w") as dst:
        for line in src:
            if replace_volid and old_volid in line:
                line = line.replace(old_volid, new_volid)

            prefix, cmdline = GetCmdline(line, ["append"])
            if not prefix:
                dst.write(line)
                continue

            updated = AlterKernelArgs(SplitCmdline(cmdline), rm_args, add_args)
            dst.write(prefix+" "+JoinKernelArgs(updated))
            dst.write("\n")
    os.replace(new_path, cfg_path)


def EditGrub2(rm_args, add_args, new_volid, old_volid, tmpdir):
    """
    Rewrite the kernel cmdline in the GRUB2 UEFI and BIOS config files

    Each of the known grub config files that exists under tmpdir is edited.
    Every occurrence of old_volid is replaced with new_volid when they differ.
    On 'linux' and 'linuxefi' lines the arguments in rm_args are removed and
    the arguments in add_args are added. All other lines are copied as they are.

    Logs a warning and does nothing if none of the config files exist.
    """
    candidates = ["EFI/BOOT/grub.cfg", "EFI/BOOT/BOOT.conf",
                  "boot/grub2/grub.cfg", "boot/grub/grub.cfg"]
    present = [tmpdir + "/" + name for name in candidates
               if os.path.exists(tmpdir + "/" + name)]
    if not present:
        log.warning("No grub config files found")
        return

    replace_volid = new_volid != old_volid

    for cfg_path in present:
        new_path = cfg_path + ".new"
        # Write the edited config to .new, then move it over the original
        with open(cfg_path, "r") as src, open(new_path, "w") as dst:
            for line in src:
                if replace_volid and old_volid in line:
                    line = line.replace(old_volid, new_volid)

                # Some start with linux (BIOS/aarch64), others with linuxefi (x86_64)
                prefix, cmdline = GetCmdline(line, ["linuxefi", "linux"])
                if not prefix:
                    dst.write(line)
                    continue

                updated = AlterKernelArgs(SplitCmdline(cmdline), rm_args, add_args)
                dst.write(prefix+" "+JoinKernelArgs(updated))
                dst.write("\n")
        os.replace(new_path, cfg_path)


def EditS390(rm_args, add_args, new_volid, old_volid, tmpdir):
    """
    Rewrite the kernel cmdline in the s390 .prm config files

    The whole file is treated as one cmdline: its lines are joined with
    spaces, every occurrence of old_volid is replaced with new_volid when
    they differ, the arguments in rm_args are removed and the arguments in
    add_args are added. The result is written back wrapped at 80 columns.

    Logs a warning and does nothing if none of the config files exist.
    """
    candidates = ["images/generic.prm", "images/cdboot.prm"]
    if all(not os.path.exists(tmpdir + "/" + name) for name in candidates):
        log.warning("No s390 config files found")
        return

    replace_volid = new_volid != old_volid

    for name in candidates:
        cfg_path = tmpdir + "/" + name
        if not os.path.exists(cfg_path):
            log.warning("No %s file found", name)
            continue

        # Read the config file and turn it into a single line
        with open(cfg_path, "r") as src:
            pieces = [raw.strip() for raw in src]
        cmdline = " ".join(pieces)

        if replace_volid and old_volid in cmdline:
            cmdline = cmdline.replace(old_volid, new_volid)

        updated = AlterKernelArgs(SplitCmdline(cmdline), rm_args, add_args)
        arg_list = ListKernelArgs(updated)

        # Write the new config file, breaking at 80 columns
        new_path = cfg_path + ".new"
        with open(new_path, "w") as dst:
            dst.write(WrapKernelArgs(arg_list, width=80))
        os.replace(new_path, cfg_path)


def CheckDiscinfo(path):
    """
    If the ISO contains a .discinfo file, check the arch against the host arch

    path is the location of the extracted .discinfo file; nothing is checked
    if it does not exist. The arch is read from the third line of the file.

    Raises a RuntimeError if the arch does not match, or if the file is too
    short to contain an arch.
    """
    ## TODO -- is this even needed with the new method of rebuilding the iso?
    if not os.path.exists(path):
        return

    with open(path) as f:
        discinfo = [l.strip() for l in f]

    # A truncated file would otherwise fail with an IndexError below
    if len(discinfo) < 3:
        raise RuntimeError("%s is too short, cannot determine the iso arch" % path)

    iso_arch = discinfo[2]
    log.info("iso arch = %s", iso_arch)
    if os.uname().machine != iso_arch:
        raise RuntimeError("iso arch does not match the host arch.")


def MakeKickstartISO(input_iso, output_iso, ks="", add_paths=None,
                    cmdline="", rm_args="", new_volid="", implantmd5=True):
    """
    Make a kickstart ISO from a boot.iso or dvd

    input_iso    path of the iso to read
    output_iso   path of the iso to create
    ks           optional kickstart file, added to the root of the iso and
                 passed to the installer with inst.ks=
    add_paths    optional list of extra files and directories to add to the
                 root of the iso. The list is not modified.
    cmdline      kernel arguments to add to the boot configuration files
    rm_args      whitespace separated names of kernel arguments to remove
    new_volid    new volume id, defaults to the volume id of input_iso
    implantmd5   run implantisomd5 on the new iso when True

    Raises a RuntimeError for problems detected by this script, and
    subprocess.CalledProcessError if xorriso or osirrox fail.
    """
    # Copy so that appending the kickstart never modifies the caller's list
    add_paths = list(add_paths) if add_paths else []

    # Gather information about the input iso
    old_volid, iso_files = GetISODetails(input_iso)
    if not old_volid and not new_volid:
        raise RuntimeError("No volume id found, cannot create iso.")

    # Extract files that match the known config files.
    known_configs = {".discinfo", "isolinux/isolinux.cfg",
                     "boot/grub2/grub.cfg", "boot/grub/grub.cfg",
                     "EFI/BOOT/grub.cfg", "EFI/BOOT/BOOT.conf",
                     "images/generic.prm", "images/cdboot.prm",
                     "images/kernel.img", "images/initrd.img"}
    extract_files = set(iso_files) & known_configs
    with tempfile.TemporaryDirectory(prefix="mkksiso-") as tmpdir:
        ExtractISOFiles(input_iso, extract_files, tmpdir)
        CheckDiscinfo(tmpdir + "/.discinfo")
        new_volid = new_volid or old_volid
        log.info("Volume Id = %s", new_volid)

        # The config files refer to the volume by its udev escaped label,
        # while the iso itself has to be given the unescaped volume id.
        new_volid_esc = udev_escape(new_volid)
        old_volid_esc = udev_escape(old_volid)

        remove_args = rm_args.split() if rm_args else []
        add_args = SplitCmdline(cmdline)

        if ks:
            add_args["inst.ks"] = ["hd:LABEL=%s:/%s" % (new_volid_esc, os.path.basename(ks))]
            add_paths.append(ks)

        # Add kickstart command and optionally change the volid of the available config files
        EditIsolinux(remove_args, add_args, new_volid_esc, old_volid_esc, tmpdir)
        EditGrub2(remove_args, add_args, new_volid_esc, old_volid_esc, tmpdir)
        EditS390(remove_args, add_args, new_volid_esc, old_volid_esc, tmpdir)

        if os.uname().machine.startswith("s390"):
            RebuildS390CDBoot(tmpdir)

        # Build the command to rebuild the iso with the changes and additions
        cmd = ["xorriso", "-indev", input_iso, "-outdev", output_iso, "-boot_image", "any", "replay"]
        if new_volid != old_volid:
            cmd.extend(["-volid", new_volid])

        # Update the config files that were extracted and modified
        for root, _, names in os.walk(tmpdir, topdown=False):
            # Path on the iso is the path below tmpdir
            isoroot = root[len(tmpdir):]
            for f in names:
                cmd.extend(["-update", root + "/" + f, isoroot + "/" + f])

        # Add the kickstart and the new files and directories
        for p in add_paths:
            cmd.extend(["-map", p, os.path.basename(p)])

        # Files of 4GiB or more need ISO level 3
        if CheckBigFiles(add_paths):
            if "-as" not in cmd:
                cmd.extend(["-as", "mkisofs"])
            cmd.extend(["-iso-level", "3"])

        if os.uname().machine.startswith("ppc64le"):
            if "-as" not in cmd:
                cmd.extend(["-as", "mkisofs"])
            cmd.extend(["-U"])

        log.debug("Running: %s", " ".join(cmd))
        subprocess.run(cmd, check=True, capture_output=False, env={"LANG": "C"})

    if implantmd5:
        ImplantMD5(output_iso)


def setup_arg_parser():
    """ Build and return the argparse.ArgumentParser for the cmdline.

    Paths given on the cmdline (--add, --ks and the positional arguments)
    are converted to absolute paths.
    """
    ap = argparse.ArgumentParser(description="Add a kickstart and files to an iso")

    # Optional arguments
    ap.add_argument(
        "-a", "--add", action="append", dest="add_paths", default=[],
        type=os.path.abspath,
        help="File or directory to add to ISO (may be used multiple times)")
    ap.add_argument(
        "-c", "--cmdline", dest="cmdline", metavar="CMDLINE", default="",
        help="Arguments to add to kernel cmdline")
    ap.add_argument(
        "-r", "--rm-args", dest="rm_args", metavar="ARGS", default="",
        help="Space separated list of arguments to remove from the kernel cmdline")
    ap.add_argument(
        "--debug", action="store_const", const=log.DEBUG,
        dest="loglevel", default=log.INFO,
        help="print debugging info")
    ap.add_argument(
        "--no-md5sum", action="store_false", default=True,
        help="Do not run implantisomd5 on the ouput iso")
    ap.add_argument(
        "--ks", type=os.path.abspath, metavar="KICKSTART",
        help="Optional kickstart to add to the ISO")
    ap.add_argument(
        "-V", "--volid", dest="volid", default=None,
        help="Set the ISO volume id, defaults to input's")

    # Positional arguments, the kickstart may be left out
    ap.add_argument(
        "ks_pos", nargs="?", type=os.path.abspath, metavar="KICKSTART",
        help="Optional kickstart to add to the ISO")
    ap.add_argument("input_iso", type=os.path.abspath, help="ISO to modify")
    ap.add_argument("output_iso", type=os.path.abspath,
                    help="Full pathname of iso to be created")

    return ap


def main():
    """
    Parse the cmdline, check the arguments and create the new iso

    All of the problems found with the arguments are logged before giving up.

    Returns 0 on success and 1 on error, including when one of the external
    commands run while creating the iso exits with an error.
    """
    parser = setup_arg_parser()
    args = parser.parse_args()
    log.basicConfig(format='%(levelname)s:%(message)s', level=args.loglevel)

    # The kickstart can be passed as --ks or as the first positional argument
    kickstart = args.ks or args.ks_pos

    try:
        errors = False
        for t in ["xorriso", "osirrox"]:
            if not shutil.which(t):
                log.error("%s binary is missing", t)
                errors = True

        files = [args.input_iso, *args.add_paths]
        if kickstart:
            files.append(kickstart)
        for f in files:
            if not os.path.exists(f):
                log.error("%s is missing", f)
                errors = True

        if os.path.exists(args.output_iso):
            log.error("%s already exists", args.output_iso)
            errors = True

        if "=" in args.rm_args:
            log.error("--rm-args should only list the arguments to remove, not values")
            errors = True

        if args.ks and args.ks_pos:
            log.error("Use either --ks KICKSTART or positional KICKSTART but not both")
            errors = True

        if not any([kickstart, args.add_paths, args.cmdline, args.rm_args, args.volid]):
            log.error("Nothing to do - pass one or more of --ks, --add, --cmdline, --rm-args, --volid")
            errors = True

        if errors:
            raise RuntimeError("Problems running %s" % sys.argv[0])

        MakeKickstartISO(args.input_iso, args.output_iso, kickstart,
                         args.add_paths, args.cmdline, args.rm_args,
                         args.volid, args.no_md5sum)
    except (RuntimeError, subprocess.CalledProcessError) as e:
        # CalledProcessError comes from the xorriso/osirrox calls made with
        # check=True; report it like any other error instead of a traceback
        log.error(str(e))
        return 1

    return 0


if __name__ == '__main__':
    sys.exit(main())
