#!/usr/bin/python3
import argparse
import json
import random
import socket
import sys

from collections import defaultdict

# Host of the Fulcrum admin RPC service. main() overwrites this from the -H argument.
HOST = 'localhost'
# Port of the admin RPC service. main() sets this from the required -p argument.
PORT = None
# Id used for the next JSON-RPC request; starts at a random value and is
# incremented by make_new_request() for every request it builds.
ID_NEXT = random.randint(0, 262144)
# True when the -j flag was given: main() then skips the text formatters and
# prints the server's JSON reply instead.
JSON = False
# Exit status main() passes to sys.exit() after a reply has been printed.
# Response handlers set it to 1 when the server's reply is not what they expect.
EXITSTATUS = 0


class ErrorResponse(RuntimeError):
    """Raised when the server's reply is an error or is otherwise unusable.

    The exception message is the text shown to the user.
    """


def make_new_request(method, params=None):
    """Build a JSON-RPC 2.0 request object with a fresh id.

    The id is taken from the module-level ID_NEXT counter, which is then
    advanced by one so consecutive requests get consecutive ids.

    :param method: name of the RPC method to call.
    :param params: parameters for the call; any falsy value (None, empty
        list/tuple/dict) is sent as an empty list.
    :return: dict with the keys "id", "jsonrpc", "method" and "params".
    """
    global ID_NEXT
    request = {
        "id": ID_NEXT,
        "jsonrpc": "2.0",
        "method": method,
        "params": params if params else [],
    }
    ID_NEXT += 1
    return request


def sndrecv_json(sock, outj):
    """Send one JSON message over *sock* and return the decoded reply.

    The message is serialized on a single line and terminated with a newline.
    The reply is read until the first newline; anything received after that
    newline in the same chunk is discarded.

    :param sock: a connected stream socket.
    :param outj: JSON-serializable object to send (a request dict or a list
        of request dicts for a batch).
    :return: the reply parsed with json.loads().
    :raises ConnectionError: if the peer closes the connection before a
        newline-terminated reply has arrived. This is an OSError subclass, as
        are socket timeouts and other socket failures.
    :raises ValueError: if the reply is not valid JSON (json.JSONDecodeError)
        or not valid UTF-8.
    """
    verb = False  # flip to True to trace the raw traffic on stdout
    msg = json.dumps(outj, indent=None).encode("utf8") + b'\n'
    if verb:
        print(f"Srv --> {msg[:2048]}")
    # sendall() keeps writing until the whole message is out; send() may
    # transmit only a prefix of a large request.
    sock.sendall(msg)
    resp = bytearray()
    while True:
        chunk = sock.recv(4096)
        if not chunk:
            # recv() returns b'' once the peer has closed the connection;
            # without this check the loop would spin forever.
            raise ConnectionError("Connection closed by the server before a complete response was received")
        nl_index = chunk.find(b'\n')
        if nl_index > -1:
            resp += chunk[:nl_index]
            break
        resp += chunk
    if verb:
        print(f"Srv <-- {resp[:2048]}")
    return json.loads(resp.decode("utf8").strip())


def send_request(method, params=None):
    """Send a single request to the admin RPC service and return its result.

    Opens a new connection to (HOST, PORT) with a 10 second timeout, sends one
    request built by make_new_request() and waits for the reply.

    :param method: name of the RPC method to call.
    :param params: parameters for the call (optional).
    :return: the "result" member of the reply (None if it is absent).
    :raises ErrorResponse: if the reply carries an error, is not a JSON
        object, or its id does not match the id of the request.
    :raises OSError: on connection or socket failures.
    """
    with socket.create_connection((HOST, PORT), timeout=10.0) as sock:
        req = make_new_request(method, params)
        j = sndrecv_json(sock, req)
    if not isinstance(j, dict):
        raise ErrorResponse("Error response from Fulcrum:\n\n" + json.dumps(j, indent=4))
    error = j.get("error")
    if error or j.get("id") != req['id']:
        # Show the error object when there is one; otherwise (id mismatch)
        # show the whole reply. json.dumps(None) is the truthy string "null",
        # so an `or` fallback on the dumped text would never be taken.
        details = error if error else j
        raise ErrorResponse("Error response from Fulcrum:\n\n" + json.dumps(details, indent=4))
    return j.get("result")


def send_request_batch(method_params_tuples):
    """Send several requests as one JSON-RPC batch and return their results.

    Opens a new connection to (HOST, PORT) with a 10 second timeout and sends
    all requests as a single JSON array.

    :param method_params_tuples: iterable of (method, params) pairs.
    :return: list of the "result" members, ordered by request id, i.e. in the
        order the caller specified the requests.
    :raises ErrorResponse: if the reply is not a JSON array of objects, or if
        any item carries an error or lacks a "result" member.
    :raises OSError: on connection or socket failures.
    """
    prefix = "Error response from Fulcrum:\n\n"
    batch = [make_new_request(method, params) for method, params in method_params_tuples]
    with socket.create_connection((HOST, PORT), timeout=10.0) as sock:
        jarray = sndrecv_json(sock, batch)
    if not isinstance(jarray, (list, tuple)):
        # A server may answer a whole batch with one error object. Serialize
        # it for the message: concatenating a dict to a str raises TypeError.
        details = jarray
        if isinstance(jarray, dict) and jarray.get("error"):
            details = jarray["error"]
        raise ErrorResponse(prefix + json.dumps(details, indent=4))
    if not all(isinstance(item, dict) for item in jarray):
        raise ErrorResponse(prefix + json.dumps(jarray, indent=4))
    ret = []
    # Sort the responses with respect to the 'id' so they are in the order caller specified
    for item in sorted(jarray, key=lambda x: x.get('id') or -1):
        error = item.get("error")
        if error or 'result' not in item:
            # Show the error object if present, else the whole malformed item.
            raise ErrorResponse(prefix + json.dumps(error if error else item, indent=4))
        ret.append(item["result"])
    return ret


def main():
    """Entry point of the admin CLI.

    Parses the command line, stores the connection settings in the module
    globals HOST, PORT and JSON, then selects for the chosen sub-command:

    * ``command_params`` -- the parameters sent with the request,
    * ``response_handler`` -- a callable turning the server's result into the
      text that is printed (the default pretty-prints it as JSON),
    * ``my_send_request`` -- the function used to perform the request
      (send_request, or a batch wrapper for the 'query' command).

    The process always ends through sys.exit(): with EXITSTATUS after a reply
    was printed, or with 1 on usage errors, socket errors (OSError) and error
    replies (ErrorResponse).
    """
    global HOST, PORT, JSON, EXITSTATUS
    # --- Command-line definition: global options first, then one sub-parser per command ---
    parser = argparse.ArgumentParser(prog="FulcrumAdmin", description="Fulcrum CLI admin tool")
    parser.add_argument('-p', type=int, metavar="port", nargs=1, required=True, help=f"Specify the port for the Fulcrum admin RPC service. This is a required argument.")
    parser.add_argument('-j', action='store_true', dest="json", help=f"Print the response from the server as JSON, not as formatted text")
    parser.add_argument('-H', type=str, metavar="host", nargs='?', default=HOST, help=f"Specify the host for the Fulcrum admin RPC service. Defaults to {HOST}.")
    subparsers = parser.add_subparsers(title="command", description="Select from one of the following commands:", dest="command")
    addpeer = subparsers.add_parser('addpeer', help="Add a peer to the server's list of peers")
    addpeer.add_argument('hostname', metavar='hostname', nargs=1, help="Hostname of peer.")
    addpeer.add_argument('-s', metavar='ssl_port', type=int, nargs='?', help="Peer's SSL port.")
    addpeer.add_argument('-t', metavar='tcp_port', type=int, nargs='?', help="Peer's TCP port.")
    ban = subparsers.add_parser('ban', help="Ban clients by ID and/or IP address")
    ban.add_argument('id_or_ip', metavar='ipaddress_or_id', nargs='+', help="Client ID or IP address to ban.")
    banpeer = subparsers.add_parser('banpeer', help="Ban peers by hostname suffix")
    banpeer.add_argument('hostnames', metavar='hostname', nargs='+', help="A hostname or hostname suffix e.g. somehost.com or *some.host.com.")
    bitcoind_throttle = subparsers.add_parser('bitcoind_throttle', help="Query or set server bitcoind_throttle setting")
    bitcoind_throttle.add_argument('param', metavar='param', nargs='*', help='The new desired setting. Specify 3 arguments to set this properly for: high low decay. Omit arguments to query.')
    clients = subparsers.add_parser('clients', help="Print information on all the currently connected clients", aliases=['sessions'])
    getinfo = subparsers.add_parser('getinfo', help="Get server information")
    kick = subparsers.add_parser('kick', help="Kick clients by ID and/or IP address")
    kick.add_argument('id_or_ip', metavar='ipaddress_or_id', nargs='+', help="Client ID or IP addresses to kick.")
    listbanned = subparsers.add_parser('listbanned', help="Print the list of banned IP addresses and peer hostnames", aliases=['banlist'])
    loglevel = subparsers.add_parser('loglevel', help="Set the server's logging verbosity")
    loglevel.add_argument('level', metavar='level', nargs=1, help="One of: 'normal', 'debug', or 'trace'")
    maxbuffer = subparsers.add_parser('maxbuffer', help="Query or set server max_buffer setting")
    maxbuffer.add_argument('bytes', metavar='bytes', type=int, nargs='?', help='The new desired max_buffer setting in bytes. Must be >= 64KB and <= 100MB. If omitted, then this script will just query the current value.')
    peers = subparsers.add_parser('peers', help="Print peering information")
    query = subparsers.add_parser('query', help="Query for balance, UTXO, and history information for one or more addresses")
    query.add_argument('address', metavar='address', nargs='+', help="Address or hex-encoded scriptPubKey to query")
    query.add_argument('-l', metavar='limit', type=int, nargs='?', default=1 << 63, help="UTXO and history output limit")
    rmpeer = subparsers.add_parser('rmpeer', help="Remove peers by hostname suffix")
    rmpeer.add_argument('hostnames', metavar='hostname', nargs='+', help="A hostname or hostname suffix e.g. somehost.com or *some.host.com.")
    simdjson = subparsers.add_parser('simdjson', help="Get or set the server's 'simdjson' (JSON parser) setting")
    simdjson.add_argument('enabled', type=int, nargs='?',
                          help='Flag used to enable or disable the simdjson JSON parser on the server (1=enabled,'
                               ' 0=disabled). If this option is omitted, then the current setting is queried.')
    stop = subparsers.add_parser('stop', help="Gracefully shut down the server", aliases=['shutdown'])
    unban = subparsers.add_parser('unban', help="Unban IP addresses")
    unban.add_argument('ips', metavar='ipaddress', nargs='+', help="Specify an existing banned IP address to unban.")
    unbanpeer = subparsers.add_parser('unbanpeer', help="Unban peers by hostname suffix")
    unbanpeer.add_argument('hostnames', metavar='hostname', nargs='+', help="Specify an existing peer ban to unban.")

    args = parser.parse_args()

    # Publish the connection settings to the module globals read by the send_request* functions.
    HOST = args.H
    PORT, = args.p  # -p uses nargs=1, so argparse delivers a one-element list
    JSON = args.json

    # NOTE(review): only the upper bound is checked here; zero or negative ports are not rejected.
    if PORT > 65535:
        sys.exit("Port argument must be < 65536")

    # Map the command aliases back to the canonical command name, which is also the RPC method name.
    command = args.command
    if command == 'sessions': command = 'clients'  # Is there a better way to do this by referring back to the subparser above?? TODO
    if command == 'shutdown': command = 'stop'
    if command == 'banlist' : command = 'listbanned'
    command_params = tuple()
    response_handler = lambda r: json.dumps(r, indent = 4)  # default handler just pretty-prints the JSON
    my_send_request = send_request

    # --- Per-command setup. Commands without a matching branch (or with -j where the branch
    # --- requires "not JSON") keep the defaults assigned above.
    if command is None:
        print("Please specify a command to run.\n")
        parser.print_help()
        sys.exit(1)

    elif command == 'getinfo':
        orig_handler = response_handler
        def handler(r):
            # mogrify the uptime field to be more useful to humans
            if r.get('uptime'):
                r['uptime'] = formatTimeField(r['uptime'])
            return orig_handler(r)
        response_handler = handler

    elif command == 'stop' and not JSON:
        def handler(r):
            if isinstance(r, bool) and r:
                return "Fulcrum server is shutting down"
            else:
                global EXITSTATUS
                EXITSTATUS = 1
                return f"Unexpected response from Fulcrum server: {r}"
        response_handler = handler

    elif command == 'clients' and not JSON:
        response_handler = clients_handler

    elif command == 'maxbuffer':
        # A falsy value (argument omitted, or 0) leaves the params empty, which the handler reports as a query.
        command_params = [args.bytes] if args.bytes else command_params
        if not JSON:
            def handler(x):
                if command_params:
                    return f"Server max_buffer setting -> {x}"
                else:
                    return f"Server max_buffer setting is: {x}"
            response_handler = handler

    elif command == 'simdjson':
        # Unlike maxbuffer, 0 is a meaningful value here (disable), hence the explicit None test.
        command_params = [bool(args.enabled)] if args.enabled is not None else command_params
        if not JSON:
            def handler(x):
                if command_params:
                    # set
                    assert isinstance(x, bool)
                    if x:
                        verb = "enabled" if command_params[0] else "disabled"
                        return f"Server simdjson parser {verb} successfully"
                    else:
                        verb = "enable" if command_params[0] else "disable"
                        return f"Failed to {verb} server simdjson parser"
                else:
                    # query
                    if x is not None:
                        return json.dumps(x, indent=4)
                    else:
                        return "Server simdjson parser is disabled"
            response_handler = handler

    elif command == 'bitcoind_throttle':
        command_params = args.param if len(args.param) else command_params
        if not JSON:
            def handler(x):
                # Positions beyond the three known ones are labelled "???"
                names = defaultdict(lambda: "???")
                names.update({ 0: "high", 1: "low", 2: "decay" })
                if command_params:
                    return f"Server bitcoind_throttle setting -> " + ', '.join(names[n] + ' = ' + str(i) for n,i in enumerate(x))
                else:
                    return "Server bitcoind_throttle setting is: " + ', '.join(names[n] + ' = ' + str(i) for n,i in enumerate(x))
            response_handler = handler

    # Commands whose expected reply is a plain boolean true; they differ only in where their params come from.
    elif command in ('kick', 'ban', 'banpeer', 'unban', 'unbanpeer', 'rmpeer', 'loglevel'):
        extratxt = ''
        if command in ('kick', 'ban'):
            command_params = args.id_or_ip
        elif command in ('banpeer', 'unbanpeer', 'rmpeer'):
            command_params = args.hostnames
            if command == 'rmpeer':
                extratxt = ('''

Note: Removal of 'Good' peers is not guaranteed to keep them from re-peering
with the server in the near future. If you want the specified peer(s) to never
possibly peer with this Fulcrum server, then please use the 'banpeer' command.
''')
        elif command == 'unban':
            command_params = args.ips
        elif command == 'loglevel':
            # The level name is translated to the number that is sent to the server.
            levels = { 'normal': 0, 'debug' : 1, 'trace' : 2}
            if args.level[0] not in levels:
                print("level argument must be one of:", ', '.join(levels.keys()))
                sys.exit(1)
            command_params = [levels[args.level[0]]]
            extratxt = f' -> {args.level[0]}'

        if not JSON:
            def handler(r):
                if isinstance(r, bool) and r:
                    return f"{command} command submitted for: " + ', '.join([str(x) for x in command_params]) + extratxt
                else:
                    global EXITSTATUS
                    EXITSTATUS = 1
                    return f"Unexpected response from Fulcrum server: {r}"
            response_handler = handler

    elif command == 'listbanned' and not JSON:
        response_handler = listbanned_handler

    elif command == 'peers' and not JSON:
        response_handler = peers_handler

    elif command == 'addpeer':
        if not args.s and not args.t:
            print("addpeer requires at least one of the two port arguments (-s, -t), or both.")
            sys.exit(1)
        host = args.hostname[0]
        # A port that was not given is sent as 0.
        command_params = {
            'host' : host,
            'ssl'  : args.s or 0,
            'tcp'  : args.t or 0,
        }
        if not JSON:
            def handler(r):
                if isinstance(r, bool) and r:
                    return f"{command} command submitted for: {host}"
                else:
                    global EXITSTATUS
                    EXITSTATUS = 1
                    return f"Unexpected response from Fulcrum server: {r}"
            response_handler = handler

    elif command == 'query':
        limit = int(args.l)
        if limit <= 0:
            print('query requires limit be >= 1.')
            sys.exit(1)
        if not JSON:
            response_handler = lambda r: query_address_handler(r, limit=limit)
        command_params = args.address
        # The RPC method name differs from the CLI command name.
        command = 'query_address'

        def send_request_override(method, params):
            # Sends one batch containing a separate request per address.
            batch = []
            if not JSON:
                # First request is the getinfo which tells us the "coin" and "chain" we are on, which is used
                # by the pretty formatter to output the balance as xx.xx BCH, etc.
                batch.append(('getinfo', []))
            for addr in params:
                batch.append((method, [addr]))
            return send_request_batch(batch)

        my_send_request = send_request_override

    # --- Perform the request, print the formatted reply and exit ---
    try:
        print(response_handler(my_send_request(command, command_params)))
        sys.exit(EXITSTATUS)
    except OSError as e:
        print(f"Error communicating with the admin RPC port at {HOST}:{PORT}\n\n    {e}\n")
        sys.exit(1)
    except ErrorResponse as e:
        print(f"{e}")
        sys.exit(1)


def query_address_handler(response_list, limit):
    """Format the results of a batch of address queries as text.

    Reproduces electrumx output verbatim. Note that it has quirks such that
    the history list is indexed at 0 and the utxo list is indexed at 1 :(

    If the first item of *response_list* contains both the keys 'chain' and
    'coin' it is taken to be a getinfo reply: it is used to label balances
    (e.g. "BCH", or "tBCH" off mainnet) and is not formatted as an address
    record. The input sequence itself is never modified.

    :param response_list: list or tuple of result dicts. Each address record
        must contain 'address' or 'script', plus 'balance', 'history' and
        'unspent'.
    :param limit: maximum number of history entries and of UTXOs printed per
        record; must be an int > 0.
    :return: the formatted text, lines joined with newlines.
    """
    assert isinstance(response_list, (list, tuple))
    assert isinstance(limit, int) and limit > 0
    lines = []
    sats_per_coin = 10 ** 8

    coin_str = "(COIN UNITS)"
    records = response_list
    if response_list:
        # Query the coin for proper units e.g. "BCH"
        info = response_list[0]
        if 'chain' in info and 'coin' in info:
            # Skip the getinfo reply by slicing rather than deleting it from
            # the caller's sequence: this keeps the function repeatable on
            # the same input and also works for tuples.
            records = response_list[1:]
            lchain = info['chain'].lower()
            if lchain.startswith('main'):
                prefix = ''
            elif lchain.startswith(('scale', 'sig')):
                prefix = 's'
            elif lchain.startswith('reg'):
                prefix = 'r'
            else:
                prefix = 't'
            coin_str = prefix + info['coin']

    def format_coin_str(value):
        # Render an integer amount of 1e-8 units as "<int>.<frac> <unit>",
        # trimming trailing zeroes but always keeping one fractional digit.
        sign = '-' if value < 0 else ''
        value = abs(value)
        integral_part, frac_part = divmod(value, sats_per_coin)
        frac_part_str = f'{frac_part:08d}'.rstrip('0') or '0'
        return f'{sign}{integral_part:,d}.{frac_part_str} {coin_str}'

    for r in records:
        assert isinstance(r, dict)
        assert ('address' in r or 'script' in r) and 'balance' in r and 'history' in r and 'unspent' in r
        if 'address' in r:
            lines.append(f"Address: {r['address']}")
        else:
            lines.append(f"Script: {r['script']}")
        history = r['history']
        if history:
            for n, d in enumerate(history):
                lines.append(f"History #{n:,d}: height {d['height']:,d} tx_hash {d['tx_hash']}")
                if n + 1 >= limit:
                    break
        else:
            lines.append("No history found")
        utxos = r['unspent']
        if utxos:
            for n, d in enumerate(utxos, start=1):
                lines.append(f"UTXO #{n:,d}: tx_hash {d['tx_hash']} tx_pos {d['tx_pos']:,d} "
                             f"height {d['height']:,d} value {d['value']:,d}")
                if n >= limit:
                    break
        else:
            lines.append("No UTXOs found")
        bal = r['balance']
        balance_value = int(bal['confirmed']) + int(bal['unconfirmed'])
        lines.append(f'Balance: {format_coin_str(balance_value)}')
    return '\n'.join(lines)


def clients_handler(r):
    """Format the reply to the 'clients' command as an aligned text table.

    :param r: the server's result: a list of dicts mapping a server name to a
        dict whose 'clients' entry is a list of single-client dicts.
    :return: the table (header row plus one row per client) followed by a
        newline, or an error message if *r* does not have the expected shape;
        in the latter case the global EXITSTATUS is also set to 1.
    """
    def badResp():
        global EXITSTATUS
        EXITSTATUS = 1
        return f"Unexpected response from Fulcrum server: {r!r}"

    if not isinstance(r, (list, tuple)):
        return badResp()
    rows = [("ID", "IP:PORT", "Typ", "UAgent", "ProtocolVer", "Subs", "HdrSub?", "ReqRcv", "RespSent",
             "RecvBytes", "SentBytes", "TxsSent", "Notifs", "ErrorCt", "Elapsed")]
    for serverdict in r:
        if not isinstance(serverdict, dict):
            return badResp()
        for server_name, subdict in serverdict.items():
            if not isinstance(subdict, dict):
                # in case we add stuff to here that is not a server dict some day
                continue
            # The connection type is derived from the server name's prefix.
            # 'wsss' must be tested before 'wss'. NOTE(review): presumably the
            # names look like "WsSrv ..." / "WssSrv ...", which is why four
            # and three characters are compared -- verify against the server.
            lname = server_name.lower()
            if lname.startswith('wsss'):
                typnam = 'WSS'
            elif lname.startswith('wss'):
                typnam = 'WS'
            elif lname.startswith('ssl'):
                typnam = 'SSL'
            else:
                typnam = 'TCP'

            for client in subdict.get('clients', []):
                for cdict in client.values():
                    row = (
                        cdict.get('id', -1),
                        cdict.get('remote', '?'),
                        typnam,
                        cdict.get('userAgent', 'Unk'),
                        cdict.get('version', ['?'])[-1],
                        cdict.get('nSubscriptions', -1),
                        'Y' if cdict.get('isSubscribedToHeaders') else 'N',
                        cdict.get('nRequestsRcv', -1),
                        cdict.get('nResultsSent', -1),
                        cdict.get('nBytesReceived', -1),
                        cdict.get('nBytesSent', -1),
                        cdict.get('nTxSent', -1),
                        cdict.get('nNotificationsSent', -1),
                        cdict.get('nErrorsSent', -1),
                        formatTimeField(cdict.get('connectedTime', '-')),
                    )
                    rows.append([str(x) for x in row])
    # Pad every cell to the width of the widest cell in its column.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    return '\n'.join('  '.join(cell.ljust(w) for cell, w in zip(row, widths)) for row in rows) + '\n'


def listbanned_handler(r):
    """Format the reply to the 'listbanned' command as two text tables.

    The first table lists banned client IP addresses (from the
    'Banned_IPAddrs' entry of *r*), the second banned peer hostnames (from
    'Banned_Peers'). Column widths are computed separately for each table.

    :param r: the server's result, a dict.
    :return: the two tables separated by a blank line, or an error message if
        *r* does not have the expected shape; in the latter case the global
        EXITSTATUS is also set to 1.
    """
    def badResp():
        global EXITSTATUS
        EXITSTATUS = 1
        return f"Unexpected response from Fulcrum server: {r!r}"

    def render(rows):
        # Pad every cell to the width of the widest cell in its column of
        # *this* table only.
        widths = [max(len(cell) for cell in column) for column in zip(*rows)]
        return '\n'.join('  '.join(cell.ljust(w) for cell, w in zip(row, widths)) for row in rows)

    if not isinstance(r, dict):
        return badResp()

    client_rows = [("IP", "AgeSecs", "RejectedConnections")]
    for ipaddr, subdict in r.get('Banned_IPAddrs', {}).items():
        if not isinstance(subdict, dict):
            return badResp()
        row = (ipaddr, subdict.get('age_secs', -1), subdict.get('connections_rejected', 0))
        client_rows.append([str(x) for x in row])
    clientPart = '~Client Bans~\n' + render(client_rows)

    peer_rows = [("HostName", "AgeSecs")]
    for hostname, subdict in r.get('Banned_Peers', {}).items():
        if not isinstance(subdict, dict):
            return badResp()
        row = (hostname, subdict.get('age_secs', -1))
        peer_rows.append([str(x) for x in row])
    peerPart = '~Peer Bans~\n' + render(peer_rows)

    return clientPart + '\n\n' + peerPart


def peers_handler(r):
    """Format the reply to the 'peers' command as an aligned text table.

    Rows come from the 'peers' dict of *r* (hostname -> info dict) and from
    the optional 'failed', 'bad' and 'queued' dicts (hostname -> list of
    seven fields: ip, tcp port, ssl port, version, protocol version, elapsed
    time, message).

    :param r: the server's result, a dict.
    :return: the table followed by a newline, or an error message if *r* does
        not have the expected shape; in the latter case the global EXITSTATUS
        is also set to 1.
    """
    def badResp():
        global EXITSTATUS
        EXITSTATUS = 1
        return f"Unexpected response from Fulcrum server: {r!r}"

    if not isinstance(r, dict):
        return badResp()
    rows = [("Hostname", "IP", "Status", "TCP", "SSL", "Version", "ProtoMin", "ProtoMax", "Elapsed", "Message",
             "RetryPeriod")]
    retryTimes = {
        "bad": r.get('activeTimers', {}).get("badPeerRetry", 0),
        "failed": r.get('activeTimers', {}).get("failedPeerRetry", 0),
    }
    # peers (good, connected peers)
    peerdict = r.get('peers')
    if not isinstance(peerdict, dict):
        return badResp()
    for hostname, d in peerdict.items():
        if not isinstance(d, dict):
            return badResp()
        row = (
            hostname,
            d.get('addr', '-'),
            'Good' if d.get('verified') == True else 'Verifying',
            d.get('tcp_port') or '-',
            d.get('ssl_port') or '-',
            d.get('server_version', 'Unk'),
            d.get('protocol_min', '-'),
            d.get('protocol_max', '-'),
            formatTimeField(d.get('connectedTime', '-')),
            '-', '-',
        )
        rows.append([str(x) for x in row])
    # next, do the failed, bad, and queued peers
    for status in ('failed', 'bad', 'queued'):
        d = r.get(status)
        if not d:
            continue  # skip empties
        if not isinstance(d, dict):
            return badResp()
        # The retry timers are divided by 1000 for display in seconds
        # (presumably they are in milliseconds -- TODO confirm). 'queued' has
        # no timer and shows ~0s.
        retry = '~' + str(retryTimes.get(status, 0) // 1000) + 's'
        for hostname, fields in d.items():
            if not isinstance(fields, (list, tuple)):
                return badResp()
            # Tolerate entries with fewer or more than the seven known fields:
            # missing ones are shown as '-', extra ones are ignored. Silently
            # ignoring a failed unpack would leave the names below undefined
            # or holding the previous row's values.
            ip, tcp, ssl, subver, pver, elapsed, msg = (list(fields) + ['-'] * 7)[:7]
            row = (
                hostname, ip, status.title(), tcp if tcp else '-', ssl if ssl else '-',
                subver, pver, pver, formatTimeField(elapsed), msg, retry,
            )
            rows.append([str(x) for x in row])

    # Pad every cell to the width of the widest cell in its column.
    widths = [max(len(cell) for cell in column) for column in zip(*rows)]
    return '\n'.join('  '.join(cell.ljust(w) for cell, w in zip(row, widths)) for row in rows) + '\n'


def formatTimeField(s):
    """Turn a duration into a short human-readable string.

    Accepted inputs:

    * None -- rendered as '-'.
    * a string ending in "hours" or "secs" whose first whitespace-separated
      token is a number -- converted to seconds and formatted as below; any
      other string is returned unchanged.
    * an int or float number of seconds -- expressed in the largest unit that
      fits, choosing from years (365 days), months (30 days), days, hours,
      mins and secs.

    Values of any other type are returned unchanged.
    """
    if s is None:
        return '-'  # shorter than printing "None"

    hour = 60.0 * 60.0
    if isinstance(s, str):
        if s.endswith("hours"):
            multiplier = hour
        elif s.endswith("secs"):
            multiplier = 1.0
        else:
            return s  # not a field we know how to interpret
        try:
            s = float(s.split()[0].strip()) * multiplier
        except (ValueError, TypeError, IndexError):
            return s  # the numeric part did not parse; hand back the original text

    # From here on s must be a number of seconds; anything else is passed through.
    if not isinstance(s, (int, float)):
        return s

    day = hour * 24.0
    # (unit length in seconds, label, decimals), largest unit first
    units = (
        (365.0 * day, 'years', 3),
        (30.0 * day, 'months', 3),
        (day, 'days', 2),
        (hour, 'hours', 2),
        (60.0, 'mins', 1),
    )
    for unit_secs, label, decimals in units:
        if s >= unit_secs:
            return f'{s / unit_secs:0.{decimals}f} {label}'
    return f'{s:0.1f} secs'


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
