#!/usr/bin/python3

# pylint: disable=too-many-lines

"""Manifest-Pre-Processor

This manifest-pre-processor takes a path to a manifest, loads it,
runs various pre-processing options and then produces a resultant manifest, written
to a specified filename (or stdout if filename is "-").

Manifest format version "1" and "2" are supported.

Pipeline Import:

This tool imports a pipeline from another file and inserts it into a manifest
at the same position the import instruction is located. Sources from the
imported manifest are merged with the existing sources.

The parameters for this pre-processor for format version "1" look like this:

```
...
    "mpp-import-pipeline": {
      "path": "./manifest.json"
    }
...
```

The parameters for this pre-processor for format version "2" look like this:

```
...
    "mpp-import-pipeline": {
      "path": "./manifest.json",
      "id": "build"
    }
...
```

Version "2" also supports including multiple (or all) pipelines from a manifest:

```
...
    "mpp-import-pipelines": {
      "path": "./manifest2.json",
    }
...
```
```
...
    "mpp-import-pipelines": {
      "path": "./manifest3.json",
      "ids": ["build", "image"]
    }
...
```



Depsolving:

This tool adjusts the `org.osbuild.rpm` stage. It consumes the `mpp-depsolve`
option and produces a package-list and source-entries.

It supports version "1" and version "2" of the manifest description format.

The parameters for this pre-processor, version "1", look like this:

```
...
    {
      "name": "org.osbuild.rpm",
      ...
      "options": {
        ...
        "mpp-depsolve": {
          "architecture": "x86_64",
          "module-platform-id": "f32",
          "solver": "dnf",
          "baseurl": "http://mirrors.kernel.org/fedora/releases/32/Everything/x86_64/os",
          "repos": [
            {
              "id": "default",
              "metalink": "https://mirrors.fedoraproject.org/metalink?repo=fedora-32&arch=$basearch"
            }
          ],
          "packages": [
            "@core",
            "dracut-config-generic",
            "grub2-pc",
            "kernel"
          ],
          "excludes": [
            (optional excludes)
          ]
        }
      }
    }
...
```

The parameters for this pre-processor, version "2", look like this:

```
...
    {
      "name": "org.osbuild.rpm",
      ...
      "inputs": {
        packages: {
          "mpp-depsolve": {
              see above
          }
        }
      }
    }
...
```


Container resolving:

This tool adjusts the `org.osbuild.skopeo` stage. It consumes the `mpp-resolve-images`
option and produces image digests and source-entries.

It supports version "2" of the manifest description format.

The parameters for this pre-processor, version "2", look like this:

```
...
    {
      "name": "org.osbuild.skopeo",
      ...
      "inputs": {
        "images": {
          "mpp-resolve-images": {
            "images": [
              {
                 "source": "docker.io/library/ubuntu",
                 "name": "localhost/myimagename"
              },
              {
                 "source": "quay.io/centos/centos",
                 "tag": "centos7",
              }
            ]
          }
        }
      }
    }
...
```

The "source" key is required and specifies where to get the image.
Optional keys "tag" and "digest" allow specifying a particular version
of the image, otherwise the "latest" tag is used. If "name" is specified
that is used as the custom name for the container when installed.


Variable expansion and substitution:

The variables can be set in the mpp-vars toplevel dict (which is removed from
the final results) or overridden by the -D,--define commandline option.
They can then be used from within the manifest via f-string formatting using
the `mpp-format-{int,string,json}` directives. You can also use `mpp-eval`
directive to just eval an expression with the variable. Additionally the variables
will be substituted via template string substitution a la `$variable` inside
the mpp blocks.


Example:


```
    {
      "mpp-vars": {
        "variable": "some string",
        "rootfs_size": 20480,
        "arch": "x86_64",
        "ref": "fedora/$arch/osbuild",
        "some_keys": { "a": True, "b": "$ref" }
     },
...
    {
      "foo": "a value",
      "bar": { "mpp-format-string": "This expands {variable} but can also eval like {variable.upper()}" }
      "disk_size": { "mpp-format-int": "{rootfs_size * 512}" }
      "details": { "mpp-eval": "some_keys" }
    }
...
```

Optional parts:

Similar to mpp-eval there is mpp-if, which also runs the code specified in the value, but
rather than inserting the return value it uses it as a boolean to select the return
value from the "then" (when true) or the "else" (when false) keys. If said key is not set
the entire if node is removed from the manifest.


Example:


```
    {
      "mpp-if": "arch == 'x86_64'",
       "then": {
         "key1": "value1"
       },
       "else": {
         "key1": "value2"
       }
     },
...
     "foo": {
        "key1": "val1"
        "key2": { "mpp-if": "arch == 'aarch64'" "then": "key2-special" }
     },
```


Defining partition layouts for disk images:

It is possible to define a partition layout via `mpp-define-image`. The defined layout
is actually written to a temporary sparse file and read back via `sfdisk`, so that all
partition data like `size` and `start` include actual padding and such. The `image`
variable will be defined with `size` and `layout` keys, the latter containing the
partition layout data. It can be accessed via the "String expansion" explained above.

Example:

```
...
    "mpp-define-image": {
      "size": "10737418240",
      "table": {
      "uuid": "D209C89E-EA5E-4FBD-B161-B461CCE297E0",
      "label": "gpt",
      "partitions": [
        {
          "id": "bios-boot",
          "start": 2048,
          "size": 2048,
          "type": "21686148-6449-6E6F-744E-656564454649",
          "bootable": true,
          "uuid": "FAC7F1FB-3E8D-4137-A512-961DE09A5549"
          "attrs": [ 60 ]
        },
        ...
    }
...
```

Embedding data and files so they can be used in inputs:

This directive allows to generate `org.osbuild.inline` and `org.osbuild.curl`
sources on the fly. `org.osbuild.inline` sources can be generated by reading
a file (via the `path` parameter) or by directly providing the data (via the `text` parameter).
`org.osbuild.curl` resources can be generated by fetching a public URL (via the `url` parameter)
The reference to the inline source will be added to the array of references of the
corresponding input. Any JSON specified via the `options` parameter will be passed
as value for the reference. Additionally, a dictionary called `embedded` will be
created and within a mapping from the `id` to the checksum so that the source can
be used in e.g. `mpp-format-string` directives.

Example:

```
...
    "stages": [
        {
          "type": "org.osbuild.copy",
          "inputs": {
            "inlinefile": {
              "type": "org.osbuild.files",
              "origin": "org.osbuild.source",
              "mpp-embed": {
                "id": "hw",
                "text": "Hallo Welt\n"
              },
              "references": {
                ...
              }
            }
          },
          "options": {
            "paths": [
              {
                "from": {
                    "mpp-format-string": "input://inlinefile/{embedded['hw']}"
                },
                "to": "tree:///testfile"
              }
            ]
          }
        }
      ]
...
```

"""


import argparse
import base64
import collections
import contextlib
import hashlib
import json
import os
import pathlib
import string
import subprocess
import sys
import tempfile
import urllib.parse
import urllib.request
from typing import Dict, List, Optional

import dnf
import hawkey
import rpm
import yaml

from osbuild.util import containers
from osbuild.util.rhsm import Subscriptions

# We need to resolve an image name to a resolved image manifest digest
# and the corresponding container id (which is the digest of the config object).
# However, it turns out that skopeo is not very useful to do this, as
# can be seen in https://github.com/containers/skopeo/issues/1554
# So, we have to fall back to "skopeo inspect --raw" and actually look
# at the manifest contents.


class ImageManifest:
    """Parsed container image manifest as returned by `skopeo inspect --raw`.

    Keeps the raw manifest bytes, the decoded JSON document and the computed
    `sha256:` digest of the manifest. For manifest lists it can select the
    entry matching a wanted platform and load that manifest.
    """

    # We hardcode this to what skopeo/fedora does since we don't want to
    # depend on host specific cpu details for image resolving
    _arch_from_rpm = {
        "x86_64": "amd64",
        "aarch64": "arm64",
        "armhfp": "arm"
    }
    _default_variant = {
        "arm64": "v8",
        "arm": "v7",
    }

    @staticmethod
    def arch_from_rpm(rpm_arch):
        """Translate an RPM architecture name; unknown names are returned unchanged."""
        return ImageManifest._arch_from_rpm.get(rpm_arch, rpm_arch)

    @staticmethod
    def load(imagename, tag=None, digest=None):
        """Fetch and parse the raw manifest of `imagename` via skopeo.

        `digest` takes precedence over `tag`; when neither is given the bare
        image name is used. Raises `subprocess.CalledProcessError` when
        skopeo exits with a non-zero status.
        """
        if digest:
            src = f"docker://{imagename}@{digest}"
        elif tag:
            src = f"docker://{imagename}:{tag}"
        else:
            src = f"docker://{imagename}"

        res = subprocess.run(["skopeo", "inspect", "--raw", src],
                             stdout=subprocess.PIPE,
                             check=True)
        m = ImageManifest(res.stdout)
        m.name = imagename
        m.tag = tag
        m.source_digest = digest

        return m

    def __init__(self, raw_manifest):
        """Parse `raw_manifest` (the manifest bytes) and compute its digest."""
        self.name = None
        self.tag = None
        self.source_digest = None
        self.raw = raw_manifest
        self.json = json.loads(raw_manifest)
        self.schema_version = self.json.get("schemaVersion", 0)
        self.media_type = self.json.get("mediaType", "")

        self._compute_digest()

    # Based on joseBase64UrlDecode() from docker
    @staticmethod
    def _jose_base64url_decode(data):
        """Decode unpadded, possibly whitespace-wrapped base64url `data` to bytes."""
        # Strip whitespace. Strings are immutable, so the result of replace()
        # has to be kept; otherwise the padding below is computed from the
        # length including the whitespace.
        data = data.replace("\n", "").replace(" ", "")
        # Pad data with = to make it valid
        rem = len(data) % 4
        if rem > 0:
            data += "=" * (4 - rem)
        return base64.urlsafe_b64decode(data)

    def _compute_digest(self):
        """Set `self.digest` to the sha256 digest of the (unsigned) manifest."""
        raw = self.raw

        # If this is an old docker v1 signed manifest we need to remove the JWS signature
        if self.schema_version == 1 and "signatures" in self.json:
            formatLength = 0
            formatTail = ""
            for s in self.json["signatures"]:
                header = json.loads(ImageManifest._jose_base64url_decode(s["protected"]))
                formatLength = header["formatLength"]
                formatTail = ImageManifest._jose_base64url_decode(header["formatTail"])
            # The unsigned payload is the first formatLength bytes plus the tail
            raw = raw[0:formatLength] + formatTail

        self.digest = "sha256:" + hashlib.sha256(raw).hexdigest()

    def is_manifest_list(self):
        """Return whether the parsed document is a manifest list."""
        return containers.is_manifest_list(self.json)

    def _match_platform(self, wanted_arch, wanted_os, wanted_variant):
        """Return the digest of the first list entry matching the platform.

        Architecture and os must match exactly; the variant is only compared
        when `wanted_variant` is set. Returns None if nothing matches.
        """
        for m in self.json.get("manifests", []):
            platform = m.get("platform", {})
            arch = platform.get("architecture", "")
            ostype = platform.get("os", "")
            variant = platform.get("variant", None)

            if arch != wanted_arch or wanted_os != ostype:
                continue

            if wanted_variant and wanted_variant != variant:
                continue

            return m["digest"]

        return None

    def resolve_list(self, wanted_arch, wanted_os, wanted_variant):
        """Resolve a manifest list to the manifest for the wanted platform.

        Returns `self` when this is not a manifest list. Raises RuntimeError
        when no entry of the list matches.
        """
        if not self.is_manifest_list():
            return self

        digest = None

        if wanted_variant:
            # Variant specified, require exact match
            digest = self._match_platform(wanted_arch, wanted_os, wanted_variant)
        else:
            # No variant specified, first try exact match with default variant for arch (if any)
            default_variant = ImageManifest._default_variant.get(wanted_arch, None)
            if default_variant:
                digest = self._match_platform(wanted_arch, wanted_os, default_variant)

            # Else, pick first with any (or no) variant
            if not digest:
                digest = self._match_platform(wanted_arch, wanted_os, None)

        if not digest:
            raise RuntimeError(
                f"No manifest matching architecture '{wanted_arch}', os '{wanted_os}', variant '{wanted_variant}'.")

        return ImageManifest.load(self.name, digest=digest)

    def get_config_digest(self):
        """Return the digest of the image config object ("" if absent).

        Raises RuntimeError for schema version 1 manifests and for manifest
        lists, neither of which is supported here.
        """
        if self.schema_version == 1:
            # The way the image id is extracted for old v1 images is super weird, and
            # there is no easy way to get it from skopeo.
            # So, let's just not support them instead of living in the past.
            raise RuntimeError("Old docker images with schema version 1 not supported.")
        if self.is_manifest_list():
            raise RuntimeError("No config existis for manifest lists.")

        return self.json.get("config", {}).get("digest", "")


# pylint: disable=too-many-ancestors
class YamlOrderedLoader(yaml.Loader):
    """YAML loader whose mappings are built as `collections.OrderedDict`."""

    def construct_mapping(self, node, deep=False):
        """Build an OrderedDict from a mapping node, keeping document order.

        Raises `yaml.constructor.ConstructorError` when `node` is not a
        mapping node or when one of its keys is not hashable.
        """
        if not isinstance(node, yaml.MappingNode):
            raise yaml.constructor.ConstructorError(None, None,
                                                    f"expected a mapping node, but found {node.id}",
                                                    node.start_mark)

        ordered = collections.OrderedDict()
        for key_node, value_node in node.value:
            key = self.construct_object(key_node, deep=deep)
            if isinstance(key, collections.abc.Hashable):
                ordered[key] = self.construct_object(value_node, deep=deep)
                continue
            raise yaml.constructor.ConstructorError("while constructing a mapping", node.start_mark,
                                                    "found unhashable key", key_node.start_mark)
        return ordered

    def construct_yaml_map(self, node):
        """Generator constructor: yield an empty OrderedDict, then fill it."""
        result = collections.OrderedDict()
        yield result
        # Filled only after the yield, i.e. once the caller resumes the generator
        result.update(self.construct_mapping(node))


# Register the ordered mapping constructor for the standard YAML map tag, so
# mappings with that tag are built via YamlOrderedLoader.construct_yaml_map.
# NOTE(review): no Loader argument is passed, so this presumably registers on
# PyYAML's default loader classes as well - confirm against the PyYAML in use.
yaml.add_constructor('tag:yaml.org,2002:map', YamlOrderedLoader.construct_yaml_map)


def yaml_load_ordered(source):
    """Parse YAML from `source` with `YamlOrderedLoader` and return the result."""
    # NOTE(review): YamlOrderedLoader derives from yaml.Loader rather than
    # yaml.SafeLoader - confirm that manifests are always trusted input.
    document = yaml.load(source, YamlOrderedLoader)
    return document


def json_load_ordered(source):
    """Parse JSON from the file-like `source`; objects become OrderedDicts."""
    pairs_hook = collections.OrderedDict
    return json.load(source, object_pairs_hook=pairs_hook)


def element_enter(element, key, default):
    """Return `element[key]`, first storing a copy of `default` if it is missing."""
    if key in element:
        return element[key]

    # Store a copy so the caller's default object is never shared
    element[key] = default.copy()
    return element[key]


class EmbeddedFile:
    """Plain object without predefined attributes; the class sets none itself."""

    def __init__(self) -> None:
        """Create an instance; no attributes are initialized here."""


class PkgInfo:
    """Description of a resolved package: checksum, name, evr, arch, url, secrets."""

    def __init__(self, checksum, name, evr, arch):
        """Store the package identity; `url` and `secrets` start out as None."""
        self.checksum, self.name = checksum, name
        self.evr, self.arch = evr, arch
        self.url = None
        self.secrets = None

    @classmethod
    def from_dnf_package(cls, pkg: dnf.package.Package):
        """Build a PkgInfo from a dnf package, formatting its checksum as `type:hex`."""
        checksum = f"{hawkey.chksum_name(pkg.chksum[0])}:{pkg.chksum[1].hex()}"
        return cls(checksum, pkg.name, pkg.evr, pkg.arch)

    @property
    def evra(self):
        """The `evr` and `arch` joined by a dot."""
        return "{}.{}".format(self.evr, self.arch)

    @property
    def nevra(self):
        """The `name` and `evra` joined by a dash."""
        return "{}-{}".format(self.name, self.evra)

    def __str__(self):
        """A PkgInfo prints as its nevra."""
        return self.nevra


class PacmanSolver():
    """Package resolver that shells out to `pacman` inside a private root."""

    def __init__(self, cachedir, persistdir):
        """Remember the directories; the root falls back to /tmp/pacsolve."""
        self._cachedir = cachedir if cachedir else "/tmp/pacsolve"
        self._persistdir = persistdir

    def _config_path(self):
        """Path of the pacman.conf inside the root directory."""
        return os.path.join(self._cachedir, "etc", "pacman.conf")

    def setup_root(self):
        """Create the root directory and the sub-directories used by pacman."""
        base = self._cachedir
        for parts in ((), ("var", "lib", "pacman"), ("etc",)):
            os.makedirs(os.path.join(base, *parts), exist_ok=True)

    def reset(self, arch, _, _module_platform_id, _ignore_weak_deps):
        """Prepare the root and write a fresh pacman.conf for `arch`.

        The remaining parameters are accepted but not used.
        """
        self.setup_root()
        config_lines = [
            "",
            "[options]",
            f"Architecture = {arch}",
            "CheckSpace",
            "SigLevel    = Required DatabaseOptional",
            "LocalFileSigLevel = Optional",
            "",
        ]
        with open(self._config_path(), "w", encoding="utf-8") as conf:
            conf.write("\n".join(config_lines))

    def add_repo(self, desc, _):
        """Append a repository section (`id`, `baseurl` from `desc`) to pacman.conf."""
        section = desc["id"]
        server = desc["baseurl"]
        with open(self._config_path(), "a", encoding="utf-8") as conf:
            conf.write(f"\n[{section}]\nServer = {server}\n")

    @staticmethod
    def _pacman(*args):
        """Run pacman with `args` and return its stdout decoded as utf-8."""
        command = ["pacman"] + list(args)
        return subprocess.check_output(command, encoding="utf-8")

    def resolve(self, packages, _):
        """Resolve `packages` and return a list of PkgInfo objects.

        Syncs the package databases first, then asks pacman to print what
        would be installed and looks up checksum and architecture for each
        printed package. The second parameter is not used.
        """
        self._pacman("-Sy", "--root", self._cachedir, "--config", self._config_path())
        listing = self._pacman("-S", "--print", "--print-format",
                               r'{"url": "%l", "version": "%v", "name": "%n"},',
                               "--sysroot", self._cachedir, *packages)
        # Every printed line is a JSON object followed by a comma; drop the
        # last comma and wrap everything in brackets to get a JSON array.
        entries = json.loads("[" + listing.strip().rstrip(",") + "]")

        resolved = []
        for entry in entries:
            details = self.parse_pkg_info(
                self._pacman("-Sii", "--sysroot", self._cachedir, entry["name"]))
            info = PkgInfo(
                "sha256:" + details["SHA-256 Sum"],
                entry["name"],
                entry["version"],
                details["Architecture"],
            )
            info.url = entry["url"]
            resolved.append(info)
        return resolved

    @staticmethod
    def parse_pkg_info(info):
        """Parse `Key : Value` lines into a dict; lines without a colon are skipped."""
        fields = {}
        for line in info.split("\n"):
            if ":" not in line:
                continue
            # Split on the first colon only, values may contain colons
            key, _sep, value = line.partition(":")
            fields[key.strip()] = value.strip()
        return fields


class DepSolver:
    """Dependency solver backed by dnf.

    Wraps a single `dnf.Base` that is re-configured via `reset()`, gets its
    repositories via `add_repo()` and resolves package sets via `resolve()`.
    """

    def __init__(self, cachedir, persistdir):
        """Remember the dnf cache and persist directories and create the base."""
        self.cachedir = cachedir
        self.persistdir = persistdir
        self.basedir = None

        self.subscriptions = None
        self.secrets = {}

        self.base = dnf.Base()

    def reset(self, arch, basedir, module_platform_id, ignore_weak_deps):
        """Reset the dnf base and configure it for a new depsolve run.

        Clears goal, repos and sack as well as the remembered per-repo
        secrets. `basedir` is what relative base urls are resolved against.
        """
        base = self.base
        base.reset(goal=True, repos=True, sack=True)
        self.secrets.clear()

        if self.cachedir:
            base.conf.cachedir = self.cachedir
        base.conf.config_file_path = "/dev/null"
        base.conf.persistdir = self.persistdir
        base.conf.module_platform_id = module_platform_id
        base.conf.install_weak_deps = not ignore_weak_deps
        base.conf.arch = arch

        self.base = base
        self.basedir = basedir

    def expand_baseurl(self, baseurl):
        """Expand non-uris as paths relative to basedir into a file:/// uri

        This is best effort: if the expansion fails for any reason (e.g. no
        basedir is set or `baseurl` is not a string) `baseurl` is returned
        unchanged.
        """
        basedir = self.basedir
        try:
            result = urllib.parse.urlparse(baseurl)
            if not result.scheme:
                path = basedir.joinpath(baseurl)
                return path.resolve().as_uri()
        # Deliberately broad, but not BaseException: KeyboardInterrupt and
        # SystemExit must not be swallowed here.
        except Exception:  # pylint: disable=broad-except
            pass

        return baseurl

    def get_secrets(self, url, desc):
        """Return the secrets for `url` as described by `desc`, or None.

        Only secrets named "org.osbuild.rhsm" are supported. Raises
        ValueError for any other name or when the secrets cannot be obtained.
        """
        if not desc:
            return None

        name = desc.get("name")
        if name != "org.osbuild.rhsm":
            raise ValueError(f"Unknown secret type: {name}")

        try:
            # rhsm secrets only need to be retrieved once and can then be reused
            if not self.subscriptions:
                self.subscriptions = Subscriptions.from_host_system()
            secrets = self.subscriptions.get_secrets(url)
        except RuntimeError as e:
            raise ValueError(f"Error getting secrets: {e.args[0]}") from None

        secrets["type"] = "org.osbuild.rhsm"

        return secrets

    def add_repo(self, desc, baseurl):
        """Create a dnf repo from `desc`, add it to the base and return it.

        `baseurl` is the fallback url used when `desc` contains none of
        "baseurl", "metalink" or "mirrorlist". Raises ValueError for unknown
        keys in `desc` or when no url can be determined at all.
        """
        repo = dnf.repo.Repo(desc["id"], self.base.conf)
        url = None
        url_keys = ["baseurl", "metalink", "mirrorlist"]
        skip_keys = ["id", "secrets"]
        supported = ["baseurl", "metalink", "mirrorlist",
                     "enabled", "metadata_expire", "gpgcheck", "username", "password", "priority",
                     "sslverify", "sslcacert", "sslclientkey", "sslclientcert",
                     "skip_if_unavailable"]

        for key in desc:
            if key in skip_keys:
                continue  # "id" was used above, "secrets" is handled below

            if key in url_keys:
                url = desc[key]
            if key in supported:
                value = desc[key]
                if key == "baseurl":
                    value = self.expand_baseurl(value)
                setattr(repo, key, value)
            else:
                raise ValueError(f"Unknown repo config option {key}")

        if not url:
            url = self.expand_baseurl(baseurl)

        if not url:
            raise ValueError("repo description does not contain baseurl, metalink, or mirrorlist keys")

        secrets = self.get_secrets(url, desc.get("secrets"))

        if secrets:
            if "ssl_ca_cert" in secrets:
                repo.sslcacert = secrets["ssl_ca_cert"]
            if "ssl_client_key" in secrets:
                repo.sslclientkey = secrets["ssl_client_key"]
            if "ssl_client_cert" in secrets:
                repo.sslclientcert = secrets["ssl_client_cert"]
            # Remember the secret type so resolve() can tag the packages with it
            self.secrets[repo.id] = secrets["type"]

        self.base.repos.add(repo)

        return repo

    def resolve(self, packages, excludes):
        """Depsolve `packages` (minus `excludes`) and return a list of PkgInfo.

        Each returned package carries its download url and the secret type
        of the repository it comes from (None if the repo has no secrets).
        """
        base = self.base

        base.reset(goal=True, sack=True)
        base.fill_sack(load_system_repo=False)

        base.install_specs(packages, exclude=excludes)
        base.resolve()

        deps = []

        for tsi in base.transaction:
            if tsi.action not in dnf.transaction.FORWARD_ACTIONS:
                continue

            path = tsi.pkg.relativepath
            reponame = tsi.pkg.reponame
            baseurl = self.base.repos[reponame].baseurl[0]
            baseurl = self.expand_baseurl(baseurl)
            # The relative path often starts with a "/", even though it's meant
            # to be relative to `baseurl`. Strip any leading slashes, but ensure
            # there's exactly one between `baseurl` and the path.
            url = urllib.parse.urljoin(baseurl + "/", path.lstrip("/"))

            pkg = PkgInfo.from_dnf_package(tsi.pkg)
            pkg.url = url
            pkg.secrets = self.secrets.get(reponame)

            deps.append(pkg)

        return deps


class DepSolverFactory():
    """Creates depsolvers on demand and keeps one instance per solver name."""

    def __init__(self, cachedir, persistdir):
        """Remember the directories passed on to every solver that is created."""
        self._cachedir = cachedir
        self._persistdir = persistdir
        self._solvers = {}

    def get_depsolver(self, solver):
        """Return the solver registered under `solver`, creating it if needed.

        The name "alpm" selects `PacmanSolver`; any other name `DepSolver`.
        """
        if solver in self._solvers:
            return self._solvers[solver]

        klass = PacmanSolver if solver == "alpm" else DepSolver
        self._solvers[solver] = klass(self._cachedir, self._persistdir)
        return self._solvers[solver]


class Partition:
    """A single partition entry of a partition table."""

    def __init__(self,
                 uid: str = None,
                 pttype: str = None,
                 start: int = None,
                 size: int = None,
                 bootable: bool = False,
                 name: str = None,
                 uuid: str = None,
                 attrs: List[int] = None):
        """Store the given properties; `index` starts out as None."""
        self.id, self.type = uid, pttype
        self.start, self.size = start, size
        self.bootable = bootable
        self.name, self.uuid = name, uuid
        self.attrs = attrs
        self.index = None

    @property
    def start_in_bytes(self):
        """`start` multiplied by 512; 0 when `start` is unset."""
        return self.start * 512 if self.start else 0

    @property
    def size_in_bytes(self):
        """`size` multiplied by 512; 0 when `size` is unset."""
        return self.size * 512 if self.size else 0

    @classmethod
    def from_dict(cls, js):
        """Create a partition from a dict; missing keys become None."""
        plain_keys = ("start", "size", "bootable", "name", "uuid", "attrs")
        return cls(uid=js.get("id"),
                   pttype=js.get("type"),
                   **{key: js.get(key) for key in plain_keys})

    def to_dict(self):
        """Return the truthy properties as a dict (`id` and `index` excluded)."""
        data = {}

        for field in ("start", "size", "type", "bootable", "name", "uuid"):
            value = getattr(self, field)
            if value:
                data[field] = value

        # attrs is copied so the result does not share the list
        if self.attrs:
            data["attrs"] = list(self.attrs)

        return data


class PartitionTable:
    """A partition table: a label, a uuid and a list of partitions."""

    def __init__(self, label, uuid, partitions):
        """Store label and uuid; `partitions` may be None for an empty table."""
        self.label = label
        self.uuid = uuid
        self.partitions = partitions or []

    def __getitem__(self, key) -> "Partition":
        """Look up a partition by list index (int) or by its `id` (str).

        Raises IndexError when an id is not found or the key type is not
        supported.
        """
        if isinstance(key, int):
            return self.partitions[key]
        if isinstance(key, str):
            for part in self.partitions:
                if part.id == key:
                    return part
        raise IndexError

    def write_to(self, target, sync=True):
        """Write the partition table to disk

        The table is passed as a script to `sfdisk` operating on `target`.
        With `sync` set, the partitions are afterwards updated from what was
        actually written (see `update_from`). Raises
        `subprocess.CalledProcessError` when sfdisk fails.
        """
        # Attribute bits that are passed to sfdisk by name; other bits are
        # only passed on (as numbers) when they are in the range 48..63.
        resv = {
            0: "RequiredPartition",
            1: "NoBlockIOProtocol",
            2: "LegacyBIOSBootable"
        }

        # generate the command for sfdisk to create the table
        command = f"label: {self.label}\nlabel-id: {self.uuid}"
        for partition in self.partitions:
            fields = []
            for field in ["start", "size", "type", "name", "uuid", "attrs"]:
                value = getattr(partition, field)
                if not value:
                    continue
                if field == "attrs":
                    attrs = []
                    for bit in value:
                        if bit in resv:
                            attrs.append(resv[bit])
                        elif 48 <= bit <= 63:
                            attrs.append(str(bit))
                    value = ",".join(attrs)
                fields += [f'{field}="{value}"']
            if partition.bootable:
                fields += ["bootable"]
            command += "\n" + ", ".join(fields)

        subprocess.run(["sfdisk", "-q", "--no-tell-kernel", target],
                       input=command,
                       encoding='utf-8',
                       check=True)

        if sync:
            self.update_from(target)

    def update_from(self, target):
        """Update and fill in missing information from disk

        Reads the table of `target` via `sfdisk --json` and copies start,
        size, type and name of each partition; also sets each partition's
        `index` to its position in the table.
        """
        r = subprocess.run(["sfdisk", "--json", target],
                           stdout=subprocess.PIPE,
                           encoding='utf-8',
                           check=True)
        disk_table = json.loads(r.stdout)["partitiontable"]
        disk_parts = disk_table["partitions"]

        assert len(disk_parts) == len(self.partitions)
        for i, part in enumerate(self.partitions):
            part.index = i
            part.start = disk_parts[i]["start"]
            part.size = disk_parts[i]["size"]
            part.type = disk_parts[i].get("type")
            part.name = disk_parts[i].get("name")

    @classmethod
    def from_dict(cls, js) -> "PartitionTable":
        """Create a table from a dict with "uuid", "label" and "partitions".

        "uuid" and "label" are required (KeyError otherwise). A missing or
        empty "partitions" entry yields a table without partitions, matching
        what the constructor accepts.
        """
        ptuuid = js["uuid"]
        pttype = js["label"]
        partitions = js.get("partitions") or []

        parts = [Partition.from_dict(p) for p in partitions]
        table = cls(pttype, ptuuid, parts)

        return table

    def __str__(self) -> str:
        """Return the table as indented JSON; "uuid" is omitted when unset."""
        data = {}

        if self.uuid:
            data["uuid"] = self.uuid

        data["label"] = self.label

        data["partitions"] = [
            pt.to_dict() for pt in self.partitions
        ]

        return json.dumps(data, indent=2)


class Image:
    """A disk image description: its size and its partition layout."""

    def __init__(self, size, layout):
        """Store the image `size` and the partition table `layout`."""
        self.size = size
        self.layout = layout

    @classmethod
    def from_dict(cls, js):
        """Create an image from a dict with "size" and "table" entries.

        The partition table is written to a temporary sparse file of the
        given size and read back, so the partitions of the returned layout
        carry the values actually used on disk. Raises KeyError for missing
        entries and `subprocess.CalledProcessError` when an external tool
        fails.
        """
        size = js["size"]
        data = js["table"]

        with tempfile.TemporaryDirectory() as tmp:
            image = os.path.join(tmp, "disk.img")
            # Command line arguments must be strings, but the size may well
            # be given as a number in the manifest; convert it for the call
            # only and keep the original value on the returned object.
            subprocess.run(["truncate", "--size", str(size), image], check=True)

            table = PartitionTable.from_dict(data)
            table.write_to(image)

        return cls(size, table)


# pylint: disable=too-many-instance-attributes
class ManifestFile:
    @staticmethod
    def load(path, overrides, default_vars, searchdirs):
        """Open the manifest at `path` (as UTF-8) and load it via `load_from_fd`."""
        with open(path, encoding="utf8") as stream:
            manifest = ManifestFile.load_from_fd(stream, path, overrides, default_vars, searchdirs)
        return manifest

    @staticmethod
    def load_from_fd(f, path, overrides, default_vars, searchdirs):
        """Parse a manifest from the open file `f` and return its manifest object.

        `path` selects the parser (YAML for `.yml`/`.yaml`, JSON otherwise)
        and is handed on to the manifest object. On a parse error a message
        is printed and the process exits with status 1. The `version` key
        (default "1") selects ManifestFileV1 or ManifestFileV2; imports and
        the image definition are processed before the object is returned.

        Raises ValueError for any other manifest version.
        """
        # The *_ordered loaders preserve key order (for python < 3.6)
        looks_like_yaml = path.endswith((".yml", ".yaml"))
        if looks_like_yaml:
            try:
                data = yaml_load_ordered(f)
            except yaml.YAMLError as err:
                location = ""
                if hasattr(err, 'problem_mark'):
                    mark = err.problem_mark
                    location = f": {err.problem} at line {mark.line+1} (col {mark.column+1})"
                print(f"Invalid yaml in \"{path}\"{location}")
                sys.exit(1)
        else:
            try:
                data = json_load_ordered(f)
            except json.decoder.JSONDecodeError as err:
                print(f"Invalid json in \"{path}\": {err.msg} at line {err.lineno} (col {err.colno})")
                sys.exit(1)

        version = int(data.get("version", "1"))
        implementations = {1: ManifestFileV1, 2: ManifestFileV2}
        if version not in implementations:
            raise ValueError(f"Unknown manfest version {version}")

        m = implementations[version](path, overrides, default_vars, data, searchdirs)
        m.process_imports()
        m.process_partition()
        return m

    def __init__(self, path, overrides, default_vars, root, searchdirs, version):
        """Set up the state shared by all manifest versions.

        `root` is the parsed manifest document, `default_vars` seeds the
        variable table (it is copied, not modified) and `overrides` are the
        variables that take precedence in `get_vars`. The manifest's own
        `mpp-vars` node is evaluated last, once everything it relies on
        is in place.
        """
        self.version = version
        self.root = root
        self.searchdirs = searchdirs

        self.path = pathlib.Path(path)
        self.basedir = self.path.parent

        self.sources = element_enter(self.root, "sources", {})
        self.source_urls = {}

        self.format_stack = []
        self.solver_factory = None

        # init_vars() reads both `vars` and `overrides`, so it must come last.
        self.overrides = overrides
        self.vars = default_vars.copy()
        self.init_vars()

    def get_mpp_node(self, parent: Dict, name: str) -> Optional[Dict]:
        name = "mpp-" + name

        desc = parent.get(name)
        if not desc:
            return None

        del parent[name]

        return self.substitute_vars(desc)

    def init_vars(self):
        variables = self.get_mpp_node(self.root, "vars")

        if not variables:
            return

        for k, v in variables.items():
            self.vars[k], _ = self._rewrite_node(v)
        self.substitute_vars(self.vars)

    def get_vars(self):
        return {**self.vars, **self.overrides}

    def substitute_vars(self, node):
        """Expand variables in `node` with the manifest variables"""

        if isinstance(node, dict):
            for k, v in node.items():
                node[k] = self.substitute_vars(v)
        elif isinstance(node, list):
            for i, v in enumerate(node):
                node[i] = self.substitute_vars(v)
        elif isinstance(node, str):
            tpl = string.Template(node)
            node = tpl.safe_substitute(self.get_vars())

        return node

    def load_import(self, path):
        """Load the manifest at `path` for importing.

        Raises ValueError if its format version differs from this manifest's.
        """
        imported = self.find_and_load_manifest(path)
        if imported.version == self.version:
            return imported
        raise ValueError(f"Incompatible manifest version {imported.version}")

    def find_and_open_file(self, path, dirs, mode="r", encoding="utf8"):
        for p in [self.basedir] + dirs:
            with contextlib.suppress(FileNotFoundError):
                fullpath = os.path.join(p, path)
                return open(fullpath, mode, encoding=encoding), os.path.normpath(fullpath)
        raise FileNotFoundError(f"Could not find file '{path}'")

    def find_and_load_manifest(self, path):
        """Locate `path` via the search directories and load it as a manifest.

        The loaded manifest receives this manifest's overrides, its current
        variables (as defaults) and its search directories.
        """
        handle, resolved = self.find_and_open_file(path, self.searchdirs)
        with handle:
            return ManifestFile.load_from_fd(handle, resolved, self.overrides, self.vars, self.searchdirs)

    def depsolve(self, desc: Dict):
        repos = desc.get("repos", [])
        packages = desc.get("packages", [])
        excludes = desc.get("excludes", [])
        baseurl = desc.get("baseurl")
        arch = desc.get("architecture")
        solver = self.solver_factory.get_depsolver(desc.get("solver", "dnf"))

        if not packages:
            return []

        module_platform_id = desc["module-platform-id"]
        ignore_weak_deps = bool(desc.get("ignore-weak-deps"))

        solver.reset(arch, self.basedir, module_platform_id, ignore_weak_deps)

        for repo in repos:
            solver.add_repo(repo, baseurl)

        return solver.resolve(packages, excludes)

    def add_packages(self, deps, pipeline_name):
        checksums = []

        pkginfos = {}

        for dep in deps:
            name, checksum, url = dep.name, dep.checksum, dep.url

            pkginfos[name] = dep

            if dep.secrets:
                data = {
                    "url": url,
                    "secrets": {"name": dep.secrets}
                }
            else:
                data = url

            self.source_urls[checksum] = data
            checksums.append(checksum)

        if "rpms" not in self.vars:
            self.vars["rpms"] = {}
        self.vars["rpms"][pipeline_name] = pkginfos

        return checksums

    def sort_urls(self):
        def get_sort_key(item):
            key = item[1]
            if isinstance(key, dict):
                key = key["url"]
            return key

        urls = self.source_urls
        if not urls:
            return urls

        urls_sorted = sorted(urls.items(), key=get_sort_key)
        urls.clear()
        urls.update(collections.OrderedDict(urls_sorted))

    def write(self, file, sort_keys=False):
        self.sort_urls()
        json.dump(self.root, file, indent=2, sort_keys=sort_keys)
        file.write("\n")

    def _rewrite_node(self, node):
        fakeroot = [node]
        self._process_format(fakeroot)
        if not fakeroot:
            return None, True
        return fakeroot[0], False

    def _format_dict_node(self, node, stack):
        if len(stack) > 0:
            parent_node = stack[-1][0]
            parent_key = stack[-1][1]
        else:
            parent_node = None
            parent_key = None

        # Avoid unnecessarily running the stage processing on things
        # that don't look like a stage. The indidual stage processing
        # will verify that the stage looks right too.
        if parent_key == "stages":
            pipeline_name = self.get_pipeline_name(parent_node)
            self._process_stage(node, pipeline_name)

    # pylint: disable=too-many-branches
    def _process_format(self, node):
        """Recursively evaluate the `mpp-*` format directives below `node`.

        `node` is modified in place: dict values and list items that are
        format directives are replaced by their evaluated result, or removed
        when the directive yields nothing (an `mpp-if` without a matching
        branch). Every dict visited is first passed to `_format_dict_node`.
        Values that are neither dicts nor lists are left alone.

        NOTE: `mpp-if`, `mpp-eval` and `mpp-format-*` are evaluated with
        `eval`, so manifests are executed as code and must be trusted.
        """
        def _is_format(node):
            """Return True if `node` is a dict holding one of the format directives."""
            if not isinstance(node, dict):
                return False
            for m in ("mpp-eval", "mpp-join", "mpp-if"):
                if m in node:
                    return True
            for m in ("int", "string", "json"):
                if f"mpp-format-{m}" in node:
                    return True
            return False

        def _eval_format(node, local_vars):
            """Evaluate the directive `node`; return `(value, remove)`.

            `remove` is True when the directive produced nothing and the
            node should be dropped from its container. `local_vars` are the
            names available to evaluated expressions.
            """
            if "mpp-join" in node:
                # Process the listed items first, then concatenate them.
                to_merge_list = node["mpp-join"]
                self._process_format(to_merge_list)
                res = []
                for to_merge in to_merge_list:
                    res.extend(to_merge)
                return res, False

            if "mpp-if" in node:
                code = node["mpp-if"]

                # pylint: disable=eval-used  # yolo this is fine!
                # Note, we copy local_vars here to avoid eval modifying it
                if eval(code, dict(local_vars)):
                    key = "then"
                else:
                    key = "else"

                # The selected branch is itself processed; a missing branch
                # means the whole node is removed.
                if key in node:
                    return self._rewrite_node(node[key])
                return None, True

            if "mpp-eval" in node:
                code = node["mpp-eval"]

                # pylint: disable=eval-used  # yolo this is fine!
                # Note, we copy local_vars here to avoid eval modifying it
                res = eval(code, dict(local_vars))
                return res, False

            # Remaining directives are mpp-format-{string,json,int}.
            if "mpp-format-string" in node:
                res_type = "string"
                format_string = node["mpp-format-string"]
            elif "mpp-format-json" in node:
                res_type = "json"
                format_string = node["mpp-format-json"]
            else:
                res_type = "int"
                format_string = node["mpp-format-int"]

            # The text is evaluated as the body of a triple-quoted f-string.
            # pylint: disable=eval-used  # yolo this is fine!
            # Note, we copy local_vars here to avoid eval modifying it
            res = eval(f'f\'\'\'{format_string}\'\'\'', dict(local_vars))

            if res_type == "int":
                res = int(res)
            elif res_type == "json":
                res = json.loads(res)
            return res, False

        if isinstance(node, dict):

            self._format_dict_node(node, self.format_stack)

            # Iterate over a snapshot of the keys, entries may be deleted.
            for key in list(node.keys()):
                # Record where we are so nested dicts can find their parent.
                self.format_stack.append((node, key))
                value = node[key]
                if _is_format(value):
                    val, remove = _eval_format(value, self.get_vars())
                    if remove:
                        del node[key]
                    else:
                        node[key] = val
                else:
                    self._process_format(value)
                self.format_stack.pop()

        if isinstance(node, list):
            # Lists push nothing on the format stack, so a dict inside a
            # list still sees the enclosing dict and key as its parent.
            to_remove = []
            for i, value in enumerate(node):
                if _is_format(value):
                    val, remove = _eval_format(value, self.get_vars())
                    if remove:
                        to_remove.append(i)
                    else:
                        node[i] = val
                else:
                    self._process_format(value)
            # Delete back to front so the remaining indices stay valid.
            for i in reversed(to_remove):
                del node[i]

    def process_format(self):
        self._process_format(self.root)

    def process_partition(self):
        desc = self.get_mpp_node(self.root, "define-image")

        if not desc:
            return

        self._process_format(desc)

        name = desc.get("id", "image")
        self.vars[name] = Image.from_dict(desc)

    # pylint: disable=no-self-use
    def get_pipeline_name(self, node):
        return node.get("name", "")

    def _process_stage(self, stage, pipeline_name):
        self._process_depsolve(stage, pipeline_name)
        self._process_embed_files(stage)
        self._process_container(stage)

    def _process_depsolve(self, _stage, _pipeline_name):
        raise NotImplementedError()

    def _process_embed_files(self, _stage):
        raise NotImplementedError()

    def _process_container(self, _stage):
        raise NotImplementedError()


class ManifestFileV1(ManifestFile):
    def __init__(self, path, overrides, default_vars, data, searchdirs):
        """Set up a version "1" manifest on top of the parsed document `data`.

        Besides the common state this looks up the `pipeline` entry of the
        document and the `urls` entry of its `org.osbuild.files` source,
        both through `element_enter` with an empty dict as default.
        """
        super().__init__(path, overrides, default_vars, data, searchdirs, 1)
        self.pipeline = element_enter(self.root, "pipeline", {})

        files_source = element_enter(self.sources, "org.osbuild.files", {})
        self.source_urls = element_enter(files_source, "urls", {})

    def _process_import(self, build):
        """Resolve an `mpp-import-pipeline` directive found in `build`.

        The referenced manifest is loaded, its variables and source URLs are
        merged into this manifest and its pipeline replaces the `pipeline`
        entry of `build`. Nothing happens if `build` has no such directive.
        """
        mpp = self.get_mpp_node(build, "import-pipeline")
        if mpp:
            imp = self.load_import(mpp["path"])
            self.vars.update(imp.vars)

            # Only manifests with URL sources can be imported so far; other
            # kinds of sources would need generic source merging first.
            assert list(imp.sources) == ["org.osbuild.files"]

            # Only `sources` and `pipeline` are taken over below, so refuse
            # manifests carrying anything else rather than silently drop it.
            assert sorted(imp.root) == ["pipeline", "sources"]

            # Merge the imported URLs into ours and put the imported pipeline
            # at the position where the import was declared, overriding any
            # pipeline that was there.
            self.source_urls.update(imp.source_urls)
            build["pipeline"] = imp.pipeline

    def process_imports(self):
        """Resolve imports at the root and in each nested build description.

        Starting at the manifest root, the chain `pipeline` -> `build` is
        followed until it ends, processing imports on every level.
        """
        level = self.root
        while level:
            self._process_import(level)
            pipeline = level.get("pipeline", {})
            level = pipeline.get("build")

    def _process_depsolve(self, stage, pipeline_name):
        """Resolve the `mpp-depsolve` option of an rpm or pacman stage.

        The directive is removed from the stage options, evaluated, and the
        checksums of the resolved packages are appended to the stage's
        `packages` option. Stages of other kinds, or without options or
        directive, are left untouched.
        """
        is_package_stage = stage.get("name", "") in ("org.osbuild.pacman", "org.osbuild.rpm")
        options = stage.get("options") if is_package_stage else None
        if not options:
            return

        mpp = self.get_mpp_node(options, "depsolve")
        if not mpp:
            return

        self._process_format(mpp)

        package_list = element_enter(options, "packages", [])
        resolved = self.depsolve(mpp)
        package_list.extend(self.add_packages(resolved, pipeline_name))

    def get_pipeline_name(self, node):
        """Name the pipeline `node` by its position in the build chain.

        The main pipeline is called "stages", its build pipeline "build",
        and deeper build pipelines "build2", "build3", ... Pipelines are
        matched by equality. Returns "" if `node` is not part of the chain.
        """
        if self.pipeline == node:
            return "stages"

        current = self.pipeline.get("build", {}).get("pipeline")
        if current == node:
            return "build"

        level = 1
        while current:
            level += 1
            current = current.get("build", {}).get("pipeline")
            if current == node:
                return f"build{level}"

        return ""

    def _process_embed_files(self, stage):
        """Do nothing: embedding files is not supported for v1 manifests."""

    def _process_container(self, stage):
        """Do nothing: installing containers is not supported for v1 manifests."""


class ManifestFileV2(ManifestFile):
    def __init__(self, path, overrides, default_vars, data, searchdirs):
        """Set up a version "2" manifest on top of the parsed document `data`.

        Besides the common state this looks up the `pipelines` list of the
        document and the `items` entry of its `org.osbuild.curl` source,
        both through `element_enter` with an empty default.
        """
        super().__init__(path, overrides, default_vars, data, searchdirs, 2)
        self.pipelines = element_enter(self.root, "pipelines", [])

        curl_source = element_enter(self.sources, "org.osbuild.curl", {})
        self.source_urls = element_enter(curl_source, "items", {})

    def get_pipeline_by_name(self, name):
        """Return the first pipeline whose `name` equals `name`.

        Raises ValueError if the manifest has no such pipeline.
        """
        found = next((entry for entry in self.pipelines if entry["name"] == name), None)
        if found is None:
            raise ValueError(f"Pipeline '{name}' not found in {self.path}")
        return found

    def _process_import(self, pipeline):
        """Expand an import directive in `pipeline` into the imported pipelines.

        `mpp-import-pipelines` imports the pipelines listed in `ids` (all of
        them when `ids` is missing or empty); `mpp-import-pipeline` imports
        the single pipeline named by `id`. Variables and sources of the
        imported manifest are merged into this one. Returns the list of
        pipelines that take the place of `pipeline`, which is `[pipeline]`
        itself when it carries no import directive.
        """
        multi = self.get_mpp_node(pipeline, "import-pipelines")
        single = None if multi else self.get_mpp_node(pipeline, "import-pipeline")

        if multi:
            mpp, wanted = multi, multi.get("ids")
        elif single:
            mpp, wanted = single, [single["id"]]
        else:
            return [pipeline]  # Not an import

        imp = self.load_import(mpp["path"])
        self.vars.update(imp.vars)

        for source, desc in imp.sources.items():
            target = self.sources.get(source)
            if not target:
                # Source not known here yet: take it over wholesale.
                self.sources[source] = desc
                continue

            if desc.get("options"):
                element_enter(target, "options", {}).update(desc["options"])

            element_enter(target, "items", {}).update(desc.get("items", {}))

        # Keep the order of the imported file. Keys left over in the
        # importing entry are merged into every imported pipeline, with the
        # imported pipeline's own keys taking precedence.
        return [
            {**pipeline, **candidate}
            for candidate in imp.pipelines
            if not wanted or candidate.get("name") in wanted
        ]

    def process_imports(self):
        """Replace every import entry in the pipeline list by the imported pipelines.

        The list object itself is kept (it is part of the manifest document)
        and refilled in the original order.
        """
        pending = self.pipelines.copy()
        self.pipelines.clear()
        for entry in pending:
            expanded = self._process_import(entry)
            self.pipelines.extend(expanded)

    def _process_depsolve(self, stage, pipeline_name):
        """Resolve the `mpp-depsolve` directive of an rpm or pacman stage.

        The directive is looked up in the stage's `packages` input (both
        `inputs` and `packages` are entered via `element_enter`), evaluated,
        and each resolved package checksum is added to the input's
        `references` with an empty options dict. Other stage types and
        stages without the directive get no references.
        """
        if stage.get("type", "") not in ("org.osbuild.pacman", "org.osbuild.rpm"):
            return

        packages = element_enter(element_enter(stage, "inputs", {}), "packages", {})

        mpp = self.get_mpp_node(packages, "depsolve")
        if not mpp:
            return

        self._process_format(mpp)

        refs = element_enter(packages, "references", {})
        resolved = self.depsolve(mpp)
        refs.update({checksum: {} for checksum in self.add_packages(resolved, pipeline_name)})

    def _process_embed_files(self, stage):
        """Resolve `mpp-embed` directives in the file inputs of `stage`.

        Only inputs of type `org.osbuild.files` with origin
        `org.osbuild.source` are considered. For each directive the data is
        taken from exactly one of `path` (a file next to the manifest),
        `url` (downloaded) or `text` (UTF-8 encoded), hashed with SHA-256
        and registered as a source: URLs under `org.osbuild.curl`,
        everything else base64-encoded under `org.osbuild.inline`. The
        digest is added to the input's `references` and published in the
        `embedded` variable under the directive's `id`.

        Raises ValueError if not exactly one of `path`, `url` or `text` is
        given.
        """

        class Embedded(collections.namedtuple("Embedded", ["id", "checksum"])):
            """Entry of the `embedded` variable; str() yields the checksum."""

            def __str__(self):
                return self.checksum

        def embed_data(ip, mpp):
            """Embed the data described by `mpp` into the input `ip`."""
            uid = mpp["id"]
            path = mpp.get("path")
            url = mpp.get("url")
            text = mpp.get("text")

            input_count = bool(text) + bool(path) + bool(url)
            if input_count == 0:
                raise ValueError(f"At least one of 'path', 'url' or 'text' must be specified for '{uid}'")
            if input_count > 1:
                raise ValueError(f"Only one of 'path', 'url' or 'text' may be specified for '{uid}'")

            if path:
                f, _ = self.find_and_open_file(path, [], mode="rb", encoding=None)
                with f:
                    data = f.read()
            elif url:
                # Read through the response object (not its raw `fp`) so the
                # transfer encoding is handled, and close the connection
                # once the body has been read.
                with urllib.request.urlopen(url) as response:
                    data = response.read()
            else:
                data = bytes(text, "utf-8")

            checksum = hashlib.sha256(data).hexdigest()
            digest = "sha256:" + checksum

            if url:
                # Remote data is referenced by URL, not stored in the manifest.
                source = element_enter(self.sources, "org.osbuild.curl", {})
                items = element_enter(source, "items", {})
                items[digest] = url
            else:
                encoded = base64.b64encode(data).decode("utf-8")
                source = element_enter(self.sources, "org.osbuild.inline", {})
                items = element_enter(source, "items", {})
                items[digest] = {
                    "encoding": "base64",
                    "data": encoded
                }

            refs = element_enter(ip, "references", {})
            refs[digest] = mpp.get("options", {})
            ef = element_enter(self.vars, "embedded", {})
            ef[uid] = Embedded(uid, digest)

        for ip in stage.get("inputs", {}).values():
            if ip.get("type") != "org.osbuild.files":
                continue

            if ip.get("origin") != "org.osbuild.source":
                continue

            mpp = self.get_mpp_node(ip, "embed")
            if not mpp:
                continue

            embed_data(ip, mpp)

    def _process_container(self, stage):
        """Resolve the `mpp-resolve-images` directive of a skopeo stage.

        Applies to `org.osbuild.skopeo` stages whose `images` input has type
        `org.osbuild.containers` and origin `org.osbuild.source`. Every
        listed image is resolved to the manifest for the requested
        architecture (default: the `arch` variable), os (default "linux")
        and variant; it is registered in the `org.osbuild.skopeo` source and
        referenced from the input under its config digest. For images with
        `index` set, the digest of the originally loaded manifest is also
        registered in `org.osbuild.skopeo-index` and collected into a
        `manifest-lists` input of the stage.
        """
        if stage.get("type", "") != "org.osbuild.skopeo":
            return

        inputs = element_enter(stage, "inputs", {})
        images_input = element_enter(inputs, "images", {})

        is_container_input = images_input.get("type", "") == "org.osbuild.containers"
        if not is_container_input or images_input.get("origin", "") != "org.osbuild.source":
            return

        mpp = self.get_mpp_node(images_input, "resolve-images")
        if not mpp:
            return

        refs = element_enter(images_input, "references", {})
        index_digests = []

        for image in element_enter(mpp, "images", []):
            source = image["source"]
            name = image.get("name", source)
            wants_index = image.get("index", False)

            top_manifest = ImageManifest.load(
                source,
                tag=image.get("tag", None),
                digest=image.get("digest", None),
            )

            # The manifest uses rpm architecture names; translate for OCI.
            rpm_arch = image.get("arch", self.get_vars()["arch"])
            oci_arch = ImageManifest.arch_from_rpm(rpm_arch)

            resolved = top_manifest.resolve_list(
                oci_arch,
                image.get("os", "linux"),
                image.get("variant", None),
            )
            image_id = resolved.get_config_digest()

            skopeo_source = element_enter(self.sources, "org.osbuild.skopeo", {})
            element_enter(skopeo_source, "items", {})[image_id] = {
                "image": {
                    "name": source,
                    "digest": resolved.digest,
                }
            }
            refs[image_id] = {"name": name}

            if wants_index:
                index_digests.append(top_manifest.digest)
                index_source = element_enter(self.sources, "org.osbuild.skopeo-index", {})
                element_enter(index_source, "items", {})[top_manifest.digest] = {
                    "image": {"name": source}
                }

        # Only create the manifest-lists input if something was collected.
        if index_digests:
            lists_input = element_enter(inputs, "manifest-lists", {})
            lists_input.update({
                "type": "org.osbuild.files",
                "origin": "org.osbuild.source",
                "references": index_digests,
            })


def main():
    """Command-line entry point of the manifest pre-processor.

    Parses the arguments, loads the source manifest (which resolves imports
    and the image definition), evaluates all remaining directives with a
    depsolver factory backed by a temporary persist directory, and writes
    the result to the destination, or to stdout if the destination is "-".
    """
    parser = argparse.ArgumentParser(description="Manifest pre processor")
    parser.add_argument(
        "--cache",
        "--dnf-cache",
        dest="cachedir",
        metavar="PATH",
        type=os.path.abspath,
        default=None,
        help="Path to package cache-directory to use",
    )
    parser.add_argument(
        "-I", "--import-dir",
        dest="searchdirs",
        default=[],
        action="append",
        help="Search for import in that directory",
    )
    parser.add_argument(
        "--sort-keys",
        dest="sort_keys",
        action='store_true',
        help="Sort keys in generated json",
    )
    parser.add_argument(
        "-D", "--define",
        default=[],
        dest="vars",
        action='append',
        help="Set/Override variable, format is key=Json"
    )
    parser.add_argument(
        "src",
        metavar="SRCPATH",
        help="Input manifest",
    )
    parser.add_argument(
        "dst",
        metavar="DESTPATH",
        help="Output manifest",
    )

    args = parser.parse_args(sys.argv[1:])

    defaults = {
        "arch": rpm.expandMacro("%{_arch}")
    }

    # Variables given on the command line override those defined in the
    # main manifest or in imported files. A bare `-D key` sets it to true.
    overrides = {}
    for arg in args.vars:
        if "=" in arg:
            key, value_s = arg.split("=", 1)
            value = json.loads(value_s)
        else:
            key = arg
            value = True
        overrides[key] = value

    m = ManifestFile.load(args.src, overrides, defaults, args.searchdirs)

    # The solver factory is only valid while its persist directory exists.
    with tempfile.TemporaryDirectory() as persistdir:
        m.solver_factory = DepSolverFactory(args.cachedir, persistdir)
        m.process_format()
        m.solver_factory = None

    # sys.stdout is deliberately not used as a context manager: that would
    # close the process-wide stream once the manifest has been written.
    if args.dst == "-":
        m.write(sys.stdout, args.sort_keys)
    else:
        with open(args.dst, "w", encoding="utf8") as f:
            m.write(f, args.sort_keys)


# Run the pre-processor only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
