#!/usr/bin/env python3
"""
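    Lab configuration generator for IPng Networks: renders Jinja2 templates
    with merged YAML data into per-node output directories.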

    (c) 2022- Pim van Pelt <pim@ipng.nl>

"""

from jinja2 import Environment, FileSystemLoader
from jinja2_ansible_filters import AnsibleCoreFiltersExtension

import hiyapyco
import traceback
import os
import sys
import pprint
import logging
import ipaddress
import re

try:
    import argparse
except ImportError:
    print("ERROR: install argparse manually")
    print("HINT: sudo pip install argparse")
    sys.exit(2)

# Script-wide logger. It starts at INFO; main() raises or lowers the level
# for the -d / -q flags.
log = logging.getLogger("generate")
log.setLevel(logging.INFO)

# The console handler itself lets everything through (DEBUG), so the level
# set on `log` is the only filter that matters.
ch = logging.StreamHandler()
formatter = logging.Formatter(
    fmt="[%(levelname)-8s] %(name)17s - %(funcName)-15s: %(message)s"
)
ch.setFormatter(formatter)
ch.setLevel(logging.DEBUG)
log.addHandler(ch)


def toyaml(d, indent=0, result=""):
    """Render a (nested) dict as simple block-style YAML text.

    Registered as the ``toyaml`` Jinja2 filter. Nested dicts become indented
    mappings (two extra spaces per level); every other value is written with
    ``str()``. String values containing a space or one of ``:{}[]#`` are
    wrapped in single quotes, with embedded single quotes doubled as YAML's
    single-quoted scalar style requires.

    Args:
        d: mapping to render.
        indent: number of leading spaces for the keys at this level.
        result: text accumulated so far; the rendered lines are appended.

    Returns:
        ``result`` with the YAML for ``d`` appended.
    """
    for key, value in d.items():
        result += " " * indent + str(key) + ": "
        if isinstance(value, dict):
            # The nested mapping starts on the next line, one level deeper.
            result = toyaml(value, indent + 2, result + "\n")
        elif isinstance(value, str) and any(c in value for c in " :{}[]#"):
            # Inside a single-quoted YAML scalar, ' must be written as ''.
            result += "'" + value.replace("'", "''") + "'\n"
        else:
            result += str(value) + "\n"
    return result


def render(tpl_path, data, trim=True):
    """Render the Jinja2 template at *tpl_path* with *data* as its context.

    A fresh Environment is built whose FileSystemLoader is rooted at the
    template's own directory (or "./" when *tpl_path* has no directory part),
    so template names it loads are resolved relative to that directory. The
    Ansible core filters and the local ``toyaml`` filter are available.

    Args:
        tpl_path: path to the template file.
        data: mapping passed to ``Template.render``.
        trim: enable Jinja2's ``trim_blocks`` and ``lstrip_blocks``.

    Returns:
        The rendered text.
    """
    path, filename = os.path.split(tpl_path)
    # Whitespace control is passed to the constructor. Jinja2 has no
    # "rstrip_blocks" option, so there is nothing else to set here.
    env = Environment(
        loader=FileSystemLoader(path or "./"),
        extensions=[AnsibleCoreFiltersExtension],
        trim_blocks=trim,
        lstrip_blocks=trim,
    )
    env.filters["toyaml"] = toyaml

    return env.get_template(filename).render(data)


def tpl2fn(tpl, prefix):
    """Derive an output filename from a template path.

    The result is built in three steps:

    1. Remove *prefix* from the start of *tpl*, but only when *tpl* actually
       begins with it.
    2. Strip any leading path separators.
    3. Drop a trailing ".j2" extension.

    Stripping the separators matters. A prefix without a trailing "/" would
    otherwise leave "/name", which os.path.join() (as used by emit()) treats
    as an absolute path, escaping the output directory.

    Example:
        tpl2fn("overlays/default/common/bird/bird.conf.j2",
               "overlays/default/common/") -> "bird/bird.conf"
    """
    fn = tpl[len(prefix):] if tpl.startswith(prefix) else tpl
    fn = fn.lstrip("/" + os.sep)
    if fn.endswith(".j2"):
        fn = fn[: -len(".j2")]
    return fn


def find(file_or_dir_list):
    """Collect templates from files and directory trees.

    Args:
        file_or_dir_list: paths to template files and/or directories.
            Entries whose path string starts with "_" are skipped, as are
            files whose name starts with "_" found while walking a
            directory. Paths that are neither a file nor a directory are
            silently ignored.

    Returns:
        dict mapping template path -> output filename. A directly listed
        file maps to its own basename minus ".j2". A file found under a
        directory ``e`` maps to ``tpl2fn(path, e)``.
    """
    log.info("Finding files in %s", file_or_dir_list)
    ret = {}
    for e in file_or_dir_list:
        if e.startswith("_"):
            continue
        if os.path.isfile(e):
            # Name the output after the file itself. Stripping the whole path
            # as the prefix would leave an empty output name.
            ret[e] = tpl2fn(os.path.basename(e), "")
        elif os.path.isdir(e):
            for root, _dirnames, filenames in os.walk(e):
                for filename in filenames:
                    if filename.startswith("_"):
                        continue
                    tpl = os.path.join(root, filename)
                    ret[tpl] = tpl2fn(tpl, e)

    log.debug("Templates: %s", ret)
    return ret


def generate(files, data, debug=False):
    """Render every template in *files* with *data*.

    Args:
        files: dict mapping template path -> output filename (as built by
            find()).
        data: template context passed to render().
        debug: when true, print the failing render's traceback to stderr.

    Returns:
        dict mapping output filename -> rendered text, or None as soon as
        any template fails to render. An empty *files* yields an empty dict.
    """
    output = {}
    for tpl, fn in files.items():
        log.info("Rendering %s into %s", tpl, fn)
        try:
            output[fn] = render(tpl, data)
        except Exception as exc:
            # Exception rather than a bare except: Ctrl-C and sys.exit()
            # must still propagate.
            log.error("Could not render %s: %s", tpl, exc)
            if debug:
                traceback.print_exc(file=sys.stderr)
            return None
    return output


def emit(output, outdir):
    """Write rendered files below *outdir*, or print them when it is "-".

    Args:
        output: dict mapping output filename (relative to *outdir*) ->
            rendered text.
        outdir: destination directory. Missing parent directories are
            created. "-" prints each file's contents to stdout instead.
    """
    log.debug("Emitting to %s", outdir)
    for fn, contents in output.items():
        if outdir == "-":
            log.info("Emitting %s", fn)
            print(contents)
            continue

        outfile = os.path.join(outdir, fn)
        log.info("Emitting %s into %s", fn, outfile)
        basedir = os.path.dirname(outfile)
        # os.makedirs("") raises, which would happen for a bare filename.
        if basedir:
            os.makedirs(basedir, exist_ok=True)
        # `with` closes the file even when write() raises.
        with open(outfile, "w", encoding="utf-8") as f:
            f.write(contents)


def prune(output, outdir):
    """Delete files under *outdir* that are not in *output*, then empty dirs.

    Args:
        output: collection of output filenames relative to *outdir* (e.g.
            the dict returned by generate()). Only membership is checked.
        outdir: build directory. "-" means output went to stdout, so nothing
            is pruned.

    Returns:
        True when pruning was skipped for stdout; None otherwise.
    """
    if outdir == "-":
        log.info("Skipping pruning, output is stdout")
        return True

    for root, _dirnames, filenames in os.walk(outdir):
        for filename in filenames:
            fn = os.path.join(root, filename)  # build/frggh0.ipng.ch/bird/bird.conf
            # relpath removes only the leading outdir. str.replace would also
            # strip any later occurrence of the same text inside the path.
            rel_fn = os.path.relpath(fn, outdir)  # bird/bird.conf
            if rel_fn not in output:
                log.info("Pruning file %s (%s)", rel_fn, fn)
                os.remove(fn)

    # Walk bottom-up so that a directory whose children were just removed is
    # itself seen as empty (and removed) in the same pass.
    for root, dirnames, _filenames in os.walk(outdir, topdown=False):
        for dirname in dirnames:
            dn = os.path.join(root, dirname)  # build/frggh0.ipng.ch/bird/empty
            if not os.listdir(dn):
                log.info("Pruning dir %s", dn)
                os.rmdir(dn)
    return None


def create_node(lab, node_id, node_type):
    """Compute the hostname and addresses for one node of a lab.

    Nodes are numbered across all types in the order of ``lab["nodes"]``.
    Node *node_id* of *node_type* gets the lab-wide index
    ``offset + node_id``, where ``offset`` is the total count of the node
    types listed before *node_type*.

    Args:
        lab: lab definition. The following keys are read:
            ``id``: the lab number.
            ``nodes``: mapping of node type -> count.
            ``ipv4`` / ``ipv6``: loopback prefixes.
            ``mgmt``: a dict with ``ipv4`` / ``ipv6`` prefixes and
            ``gw4`` / ``gw6`` gateways.
        node_id: zero-based index of the node within its type.
        node_type: a key of ``lab["nodes"]``.

    Returns:
        dict with ``hostname``, ``id``, ``mgmt`` (ipv4/ipv6 including the
        prefix length, plus gw4/gw6) and ``loopback`` (ipv4 /32, ipv6 /128).

    Raises:
        ValueError: if *node_type* is not a key of ``lab["nodes"]``.
    """
    node_counts = lab["nodes"]
    if node_type not in node_counts:
        raise ValueError("Unknown node type %r" % node_type)

    # offset = number of nodes of all types listed before node_type.
    offset = 0
    for nt, nc in node_counts.items():
        if nt == node_type:
            break
        offset += nc
    total_nodes = sum(node_counts.values())
    index = offset + node_id

    v4_base, v4_plen = lab["mgmt"]["ipv4"].split("/")
    v6_base, v6_plen = lab["mgmt"]["ipv6"].split("/")
    lo4_base = lab["ipv4"].split("/")[0]
    lo6_base = lab["ipv6"].split("/")[0]
    # Lab N's management addresses start total_nodes * N past the mgmt base.
    mgmt_index = total_nodes * lab["id"] + index

    return {
        "hostname": "%s%d-%d" % (node_type, lab["id"], node_id),
        "id": node_id,
        "mgmt": {
            "ipv4": "%s/%s" % (ipaddress.IPv4Address(v4_base) + mgmt_index, v4_plen),
            "ipv6": "%s/%s" % (ipaddress.IPv6Address(v6_base) + mgmt_index, v6_plen),
            "gw4": lab["mgmt"]["gw4"],
            "gw6": lab["mgmt"]["gw6"],
        },
        # Use the lab-wide index so nodes of different types never share a
        # loopback. For the first listed type (offset 0) this is just node_id.
        "loopback": {
            "ipv4": "%s/32" % (ipaddress.IPv4Address(lo4_base) + index),
            "ipv6": "%s/128" % (ipaddress.IPv6Address(lo6_base) + index),
        },
    }


def _build_parser():
    """Return the argparse parser for this script's command line."""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "-d", dest="debug", action="store_true", help="""Enable debug"""
    )
    parser.add_argument(
        "-q", dest="quiet", action="store_true", help="""Quiet output"""
    )
    parser.add_argument("--host", dest="hostname", help="""Hostname to configure for""")
    parser.add_argument(
        "--yaml",
        dest="yamldata",
        default=["config/common/generic.yaml"],
        type=str,
        nargs="*",
        help="""Location of YAML data file(s)""",
    )
    parser.add_argument(
        "--overlay",
        dest="overlay",
        default="default",
        type=str,
        help="""Type of lab setup (defined in config/common/generic.yaml 'overlays' dictionary, defaults to 'default')""",
    )
    parser.add_argument(
        "-o",
        dest="output",
        type=str,
        default=None,
        help="Output directory (default: overlay.build); '-' writes to stdout",
    )
    return parser


def _output_dir(args, overlay, node_hostname):
    """Return where one node's files go: a directory path, or "-" for stdout.

    An explicit ``-o DIR`` replaces the overlay's ``build`` directory. Either
    way the node's files land in ``<base>/<--host>/<node hostname>``. With no
    base directory at all, output goes to stdout.
    """
    if args.output == "-":
        return "-"
    base = args.output or overlay.get("build")
    if not base:
        return "-"
    return os.path.join(base, args.hostname, node_hostname)


def main():
    """Command-line entry point: render one overlay for every node of a lab.

    Merges the --yaml files with config/<host>.yaml (passed to hiyapyco last).
    Then, for each node type and count in ``data["lab"]["nodes"]``, it
    renders the overlay's common/, <type>/ and hostname/<node>/ templates and
    emits them.

    Returns:
        0 on success, 1 on a configuration or render error, or 2 when the
        arguments are invalid (help is printed).
    """
    parser = _build_parser()
    args = parser.parse_args()
    if (args.debug and args.quiet) or not args.hostname:
        parser.print_help()
        return 2

    if args.quiet:
        log.setLevel(logging.ERROR)
    elif args.debug:
        log.setLevel(logging.DEBUG)

    host_yaml = "config/%s.yaml" % args.hostname
    if not os.path.exists(host_yaml):
        log.error("Can't read config file %s", host_yaml)
        return 1
    log.info("Generating host %s", args.hostname)

    # Assemble the YAML dictionary; the host-specific file is merged last.
    yaml_files = args.yamldata + [host_yaml]
    log.debug("YAML data: %s", yaml_files)
    data = hiyapyco.load(*yaml_files, method=hiyapyco.METHOD_MERGE, interpolate=True)
    if args.debug:
        log.debug("YAML merged configuration")
        print(hiyapyco.dump(data, default_flow_style=False))

    overlays = data.get("overlays") or {}
    if args.overlay not in overlays:
        log.error("Overlay not defined, bailing.")
        return 1
    overlay = overlays[args.overlay]
    # The trailing "" keeps a final "/" so tpl2fn() yields relative names.
    common_root = os.path.join(overlay["path"], "common", "")

    for node_type, ncount in data["lab"]["nodes"].items():
        type_root = os.path.join(overlay["path"], node_type, "")
        for node_id in range(ncount):
            log.info("Generating for %s node %d", node_type, node_id)
            data["node"] = create_node(data["lab"], node_id, node_type)
            log.debug("node: %s", data["node"])
            node_hostname = data["node"]["hostname"]

            # Assemble a dictionary of tpl=>fn
            hostname_root = os.path.join(overlay["path"], "hostname", node_hostname, "")
            files = find([common_root, type_root, hostname_root])

            # Assemble a dictionary of fn=>output. None signals a render
            # failure; an empty dict just means this node has no templates.
            build = generate(files, data, args.debug)
            if build is None:
                return 1

            # Emit the output (fn=>output)
            emit(build, _output_dir(args, overlay, node_hostname))

            # NOTE: removal of stale files is disabled. To enable it, call
            # prune(build, <the output dir used above>) here.
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None -> 0).
    sys.exit(main())