#!/usr/bin/python3
# NOTE: if mini-tps is packaged and built for RHEL 8, which comes with the platform python,
# the shebang above will turn into "#!/usr/libexec/platform-python" automatically.

# -*- coding: utf-8 -*-
# vim: set filetype=python fileencoding=UTF-8 shiftwidth=2 tabstop=2 expandtab

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; version 2 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Library General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, see <https://gnu.org/licenses/>.

# Copyright: Red Hat Inc. 2019
# Author: Andrei Stepanov <astepano@redhat.com>
#
# https://github.com/fedora-modularity/libmodulemd

import argparse
import gi
import gzip
import logging
import sys

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Handles to the libmodulemd bindings. At most one of them ends up non-None:
# md2 is the 2.0 API, md1 is the legacy 1.0 API.
md1 = None
md2 = None

# gi does not allow two versions of the Modulemd namespace to be loaded in the
# same process, so once 2.0 has been loaded the 1.0 attempt below fails and
# md1 stays None. The failures are expected on hosts that ship only one of
# the versions, hence they are reported at debug level only.

try:
    gi.require_version("Modulemd", "2.0")  # noqa
    from gi.repository import Modulemd as md2
except Exception as err:
    logger.debug("Modulemd 2.0 is not usable: %s", err)

try:
    gi.require_version("Modulemd", "1.0")  # noqa
    from gi.repository import Modulemd as md1
except Exception as err:
    logger.debug("Modulemd 1.0 is not usable: %s", err)

if not md1 and not md2:
    # Must be visible with the INFO log level configured above; a debug
    # message here would make the script exit with status 1 silently.
    logger.error("Cannot load modulemd library.")
    sys.exit(1)

# Command-line interface: an optional modulemd path plus exactly one query
# flag (the flags are mutually exclusive and one of them is mandatory).
parser = argparse.ArgumentParser(description="A set of utils for mtps..")
group = parser.add_mutually_exclusive_group(required=True)
parser.add_argument(
    "-m",
    "--modulemd",
    metavar="FILE",
    default="modulemd.txt",
    help="Modulemd file path. Can be .gz",
)

# (option strings, help text) for every boolean query flag, in the order they
# are registered with the parser.
_QUERY_FLAGS = (
    (("-f", "--getfiltered"), "Print filtered packages from module."),
    (("-a", "--getfartifacts"), "Print artifacts from module."),
    (("--getprofiles",), "Print profiles from module."),
    (("--getstream",), "Print stream from module."),
    (("--getcontext",), "Print context from module."),
    (("--getversion",), "Print version from module."),
    (("--getnsvc",), "Print n:s:v:c from module."),
    (("--getrequires",), "Print list n:s required modules."),
)
for _option_strings, _help_text in _QUERY_FLAGS:
    group.add_argument(*_option_strings, action="store_true", help=_help_text)

args = parser.parse_args()
mmds = []

# Read the whole modulemd document into `md` as text; the query code below
# parses it from a string. A ".gz" suffix selects transparent decompression.
if not args.modulemd:
    # Only reachable with an explicitly empty "-m" value. Without a document
    # `md` would stay undefined and the queries would die with a NameError.
    logger.error("No modulemd file given.")
    sys.exit(1)

opener = gzip.open if args.modulemd.endswith(".gz") else open
# The context manager closes the file even when reading or decoding fails.
with opener(args.modulemd, "rb") as f:
    md = f.read().decode(encoding="utf-8")

if md2:
    # Query implementation on top of the Modulemd 2.0 API.

    def get_nsvc(module):
        """Return "name:stream:version:context" for *module*.

        Only the first stream of the module is used (see the
        "1 brew build == 1 stream" note below), so a module without any
        stream raises IndexError.
        """
        first_stream = module.get_all_streams()[0]
        return "%s:%s:%s:%s" % (
            module.get_module_name(),
            module.get_stream_names()[0],
            first_stream.get_version(),
            first_stream.get_context(),
        )

    def _iter_streams(modules):
        """Yield every stream of every module in *modules*, in order."""
        for module in modules:
            for stream in module.get_all_streams():
                yield stream

    # 1 brew build == 1 stream
    index = md2.ModuleIndex.new()
    ret, failures = index.update_from_string(md, True)
    # Subdocuments the library rejected are otherwise dropped silently, which
    # makes an empty or partial output impossible to diagnose. Report them on
    # the log (stderr) and carry on with whatever was parsed.
    for failure in failures or []:
        logger.warning(
            "Failed to parse a subdocument of %s: %s",
            args.modulemd,
            failure.get_gerror().message,
        )
    module_names = index.get_module_names()
    logger.debug("Found modules in {}: {}".format(args.modulemd, len(module_names)))
    modules = [index.get_module(name) for name in module_names]

    # The query flags are mutually exclusive, so at most one branch runs.
    if args.getfiltered:
        for stream in _iter_streams(modules):
            for filtered in stream.get_rpm_filters():
                print(filtered)
    elif args.getfartifacts:
        for stream in _iter_streams(modules):
            for artifact in stream.get_rpm_artifacts():
                print(artifact)
    elif args.getprofiles:
        for stream in _iter_streams(modules):
            for profile in stream.get_profile_names():
                print(profile)
    elif args.getstream:
        for module in modules:
            for stream_name in module.get_stream_names():
                print(stream_name)
    elif args.getcontext:
        for stream in _iter_streams(modules):
            print(stream.get_context())
    elif args.getversion:
        for stream in _iter_streams(modules):
            print(stream.get_version())
    elif args.getnsvc:
        for module in modules:
            print(get_nsvc(module))
    elif args.getrequires:
        for stream in _iter_streams(modules):
            for dep in stream.get_dependencies():
                for req_module in dep.get_runtime_modules():
                    # The platform pseudo-module is never reported.
                    if req_module == "platform":
                        continue
                    # One output line per required module: all of its
                    # acceptable streams as space-separated "name:stream".
                    line = " ".join(
                        "{}:{}".format(req_module, req_stream)
                        for req_stream in dep.get_runtime_streams(req_module)
                    ).strip()
                    if line:
                        print(line)


if md1 and not md2:
    # Fallback query implementation on top of the legacy Modulemd 1.0 API,
    # used only when the 2.0 bindings could not be loaded.

    def get_nsvc(mmd):
        """Return "name:stream:version:context" for the module object *mmd*."""
        return "%s:%s:%s:%s" % (
            mmd.peek_name(),
            mmd.peek_stream(),
            mmd.peek_version(),
            mmd.peek_context(),
        )

    mmds = md1.objects_from_string(md)
    logger.debug("Found modules in {}: {}".format(args.modulemd, len(mmds)))

    # The query flags are mutually exclusive, so at most one branch runs.
    if args.getfiltered:
        for mmd in mmds:
            for filtered in mmd.peek_rpm_filter().get():
                print(filtered)
    elif args.getfartifacts:
        for mmd in mmds:
            for artifact in mmd.peek_rpm_artifacts().get():
                print(artifact)
    elif args.getprofiles:
        for mmd in mmds:
            for profile in mmd.peek_profiles().keys():
                print(profile)
    elif args.getstream:
        for mmd in mmds:
            print(mmd.peek_stream())
    elif args.getcontext:
        for mmd in mmds:
            print(mmd.peek_context())
    elif args.getversion:
        for mmd in mmds:
            print(mmd.peek_version())
    elif args.getnsvc:
        for mmd in mmds:
            nsvc = get_nsvc(mmd)
            # Log the module name next to its n:s:v:c (the message used to
            # repeat the n:s:v:c in both places).
            logger.debug("n:s:v:c for module {}: {}".format(mmd.peek_name(), nsvc))
            print(nsvc)
    elif args.getrequires:
        for mmd in mmds:
            for dep in mmd.peek_dependencies():
                req_modules = dep.peek_requires()
                for req_module in req_modules:
                    # The platform pseudo-module is never reported.
                    if req_module == "platform":
                        continue
                    # One output line per required module: all of its
                    # acceptable streams as space-separated "name:stream".
                    line = " ".join(
                        "{}:{}".format(req_module, req_stream)
                        for req_stream in req_modules[req_module].get()
                    ).strip()
                    if line:
                        print(line)

