#!/usr/bin/env python

# Copyright (c) 2017 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.
#
# Red Hat trademarks are not licensed under GPLv2. No permission is
# granted to use or replicate Red Hat trademarks that are incorporated
# in this software or its documentation.
#

from collections import defaultdict
from xml.etree import ElementTree
import logging
import os
import os.path
import sys
import time

from zypp import RepoType
from zypp_plugin import Plugin
import zypp

from subscription_manager import logutil
from subscription_manager.productid import ProductManager
from subscription_manager.injectioninit import init_dep_injection

REPOMD_NS = 'http://linux.duke.edu/metadata/repo'

log = logging.getLogger('rhsm-app.' + __name__)


class ProductIdPlugin(Plugin):
    def PLUGINBEGIN(self, headers, body):
        self.ack()

    def COMMITBEGIN(self, headers, body):
        self.ack()

    def COMMITEND(self, headers, body):
        self.ack()

    def PLUGINEND(self, headers, body):
        """Spawn a daemon that will update product certs after zypper exits.

        Because zypper locks itself, we cannot do the queries we need inside
        the plugin itself. So, we fork and daemonize in order to outlast the
        zypper process.
        """
        if os.fork() == 0:
            try:
                os.setsid()
                init_dep_injection()
                manager = ZypperProductManager()
                manager.update_all()
            except Exception as e:
                log.error('Issue updating product certs')
                log.exception(e)
        else:
            self.ack()


def retry(function):
    """Retry in the case of exceptions for a max of ~30 seconds."""
    timeout = 30
    total_slept = 0
    current_timeout = 1
    while total_slept < timeout:  # TODO configurable timeout?
        try:
            return function()
        except:
            time.sleep(current_timeout)
            total_slept += current_timeout
            current_timeout = current_timeout * 2  # in the spirit of exponential backoff
            if total_slept > timeout:
                raise


class ZypperProductManager(ProductManager):
    def __init__(self):
        ProductManager.__init__(self)
        self.zypp = retry(zypp.ZYppFactory_instance().getZYpp)  # retry in case zypper lock is still held
        self.zypp.initializeTarget(zypp.Pathname("/"))
        self.zypp.target().load()
        self.repo_manager = zypp.RepoManager()

    def update_all(self):
        return self.update(self.get_enabled(), self.get_active(), False)

    def get_enabled(self):
        certs_and_repos = []
        enabled = [repo for repo in self.repo_manager.knownRepositories() if repo.enabled() and repo.type().toEnum() == RepoType.RPMMD_e]
        for repo in enabled:
            try:
                if not self.repo_manager.isCached(repo):
                    self.repo_manager.buildCache(repo)
                self.repo_manager.loadFromCache(repo)

                repo_id = repo.alias()
                filename = self.retrieve_cert(repo)
                if filename is None:
                    log.debug("Repo %s has no productid.", repo.alias())
                    continue
                cert = self._get_cert(filename)
                if cert is None:
                    continue
                certs_and_repos.append((cert, repo_id))
            except Exception as e:
                log.warn("Error loading productid metadata for %s." % repo.alias())
                log.exception(e)
                self.meta_data_errors.append(repo_id)

        if self.meta_data_errors:
            log.debug("Unable to load productid metadata for repos: %s", self.meta_data_errors)
        return certs_and_repos

    def retrieve_cert(self, repo):
        """Find the pathname to the local copy of productid.

        When zypper downloads the repo metadata, it appears to download all files defined in repomd.xml :-)
        """
        repomd_path = os.path.join(repo.metadataPath().c_str(), 'repodata', 'repomd.xml')
        repomd = ElementTree.parse(repomd_path).getroot()
        productid_data_list = [data for data in repomd.findall('{%s}data' % REPOMD_NS) if data.get('type') == 'productid']
        if len(productid_data_list) == 0:
            return None
        productid_data = productid_data_list[0]
        return os.path.join(repo.metadataPath().c_str(), productid_data.find('{%s}location' % REPOMD_NS).get('href'))

    def get_active(self):
        active = set([])

        package_repos_map = defaultdict(list)
        installed_packages = set([])

        for package in self.zypp.pool():
            # zypper uses name, edition, arch, vendor and buildtime to decide that it's equivalent
            package_identity = (package.name(), str(package.edition()), str(package.arch()), str(package.vendor()), str(package.buildtime()))
            package_repos_map[package_identity].append(package.repoInfo().alias())
            if package.status().isInstalled():
                installed_packages.add(package_identity)

        for package in installed_packages:
            for repo in package_repos_map[package]:
                active.add(repo)

        return active

if __name__ == '__main__':
    logutil.init_logger()
    plugin = ProductIdPlugin()
    plugin.main()
