#!/usr/bin/python3.6 -s
# GPL. (C) 2020 Paolo Patruno.

# This program is free software; you can redistribute it and/or modify 
# it under the terms of the GNU General Public License as published by 
# the Free Software Foundation; either version 2 of the License, or 
# (at your option) any later version. 
# 
# This program is distributed in the hope that it will be useful, 
# but WITHOUT ANY WARRANTY; without even the implied warranty of 
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
# GNU General Public License for more details. 
# 
# You should have received a copy of the GNU General Public License 
# along with this program; if not, write to the Free Software 
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA 
#

import signal
import os
import re
import json
from datetime import datetime
import dballe
import paho.mqtt.client as mqtt
import pika
import logging,traceback
import time,queue
import socket

os.environ['DJANGO_SETTINGS_MODULE'] = 'rmap.settings'
import django
django.setup()


from rmap import daemon
import rmap.settings
from rmap.rmap_core import amqpConsumerProducer
from rmap.rmap_core import Message

# Pattern for topics handled by parse_topic_v0. Layout of a matching topic:
#
#   <any prefix>/<ident>/<lon>,<lat>/<rep>/<pind>,<p1>,<p2>/<lt1>,<lv1>,<lt2>,<lv2>/<var>
#
# - ident and rep are single non-empty path segments;
# - lon and lat are runs of digits and "-" characters;
# - pind, p1, p2, lt1, lv1, lt2, lv2 are either digits or the single
#   placeholder "-";
# - var is "B" followed by exactly five digits.
TOPIC_RE_V0 = re.compile((
    r'^.*/'
    r'(?P<ident>[^/]+)/'
    r'(?P<lon>[0-9-]+),'
    r'(?P<lat>[0-9-]+)/'
    r'(?P<rep>[^/]+)/'
    r'(?P<pind>[0-9]+|-),'
    r'(?P<p1>[0-9]+|-),'
    r'(?P<p2>[0-9]+|-)/'
    r'(?P<lt1>[0-9]+|-),'
    r'(?P<lv1>[0-9]+|-),'
    r'(?P<lt2>[0-9]+|-),'
    r'(?P<lv2>[0-9]+|-)/'
    r'(?P<var>B[0-9]{5})$'
))

# Pattern for topics handled by parse_topic_v1. Layout of a matching topic:
#
#   1/<segment>/<user>/<ident>/<lon>,<lat>/<rep>/<pind>,<p1>,<p2>/<lt1>,<lv1>,<lt2>,<lv2>/<var>
#
# - the topic must start with the literal "1/" followed by one non-empty
#   segment that is not captured;
# - user and rep are single non-empty path segments;
# - ident is a single path segment and, unlike user, may be empty;
# - lon and lat are runs of digits and "-" characters;
# - pind, p1, p2, lt1, lv1, lt2, lv2 are either digits or the single
#   placeholder "-";
# - var is "B" followed by exactly five digits.
TOPIC_RE_V1 = re.compile((
    r'^1/[^/]+/'
    r'(?P<user>[^/]+)/'
    r'(?P<ident>[^/]*)/'
    r'(?P<lon>[0-9-]+),'
    r'(?P<lat>[0-9-]+)/'
    r'(?P<rep>[^/]+)/'
    r'(?P<pind>[0-9]+|-),'
    r'(?P<p1>[0-9]+|-),'
    r'(?P<p2>[0-9]+|-)/'
    r'(?P<lt1>[0-9]+|-),'
    r'(?P<lv1>[0-9]+|-),'
    r'(?P<lt2>[0-9]+|-),'
    r'(?P<lv2>[0-9]+|-)/'
    r'(?P<var>B[0-9]{5})$'
))


class MyMQTTClass(mqtt.Client):
    """MQTT client that turns received messages into BUFR and queues them.

    The client reads its configuration from the ``userdata`` dict handed to
    the callbacks:

    - ``topics``: iterable of topic filters subscribed on connect (required);
    - ``send_queue``: queue receiving one ``Message`` per encoded MQTT
      message (required);
    - ``qos``: subscription QoS (default 0);
    - ``version``: ``"v0"`` selects the V0 topic parser, anything else the
      V1 parser (default ``"v0"``);
    - ``userasident``: forwarded to :meth:`parse_topic_v1` (default True);
    - ``skipnetworks``: ``rep_memo`` values whose messages are dropped
      (default empty).
    """

    @staticmethod
    def _topic_int(field):
        """Return None for the "-" placeholder, otherwise ``int(field)``."""
        return None if field == "-" else int(field)

    @staticmethod
    def _topic_common(g):
        """Build the dict entries shared by the V0 and V1 topic layouts.

        ``g`` is the ``groupdict()`` of a successful topic match.  May raise
        ValueError when lon/lat are not valid integers (the patterns accept
        any run of digits and "-").
        """
        to_int = MyMQTTClass._topic_int
        return {
            "lon": int(g["lon"]),
            "lat": int(g["lat"]),
            "rep_memo": g["rep"],
            "level": (
                to_int(g["lt1"]),
                to_int(g["lv1"]),
                to_int(g["lt2"]),
                to_int(g["lv2"]),
            ),
            "trange": (
                to_int(g["pind"]),
                to_int(g["p1"]),
                to_int(g["p2"]),
            ),
            "var": g["var"],
        }

    @staticmethod
    def _encode_bufr(m):
        """Encode a parsed message dict (see parse_message) as BUFR bytes."""
        msg = dballe.Message("generic")
        if m["ident"] is not None:
            msg.set_named("ident", dballe.var("B01011", m["ident"]))

        msg.set_named("longitude", dballe.var("B06001", m["lon"]))
        msg.set_named("latitude", dballe.var("B05001", m["lat"]))
        msg.set_named("rep_memo", dballe.var("B01194", m["rep_memo"]))

        when = m["datetime"]
        if when is not None:
            for name, bcode in (
                    ("year", "B04001"),
                    ("month", "B04002"),
                    ("day", "B04003"),
                    ("hour", "B04004"),
                    ("minute", "B04005"),
                    ("second", "B04006"),
            ):
                msg.set_named(name, dballe.var(bcode, getattr(when, name)))

        var = dballe.var(m["var"], m["value"])
        for b, v in m["attributes"].items():
            var.seta(dballe.var(b, v))

        msg.set(m["level"], m["trange"], var)

        exporter = dballe.Exporter(encoding="BUFR")
        return exporter.to_binary(msg)

    def parse_topic_v0(self, topic):
        """Parse a V0 topic.

        Returns a dict with keys ident, lon, lat, rep_memo, level (4-tuple),
        trange (3-tuple) and var, or None (after logging a SKIP line) when
        the topic does not match TOPIC_RE_V0.  An ident of "-" becomes None.
        """
        match = TOPIC_RE_V0.match(topic)
        if match is None:
            logging.info("SKIP topic V0: {}".format(topic))
            return None

        g = match.groupdict()
        m = {"ident": None if g["ident"] == "-" else g["ident"]}
        m.update(self._topic_common(g))
        return m

    def parse_topic_v1(self, topic, userasident=True):
        """Parse a V1 topic.

        Returns the same dict shape as parse_topic_v0, or None (after logging
        a SKIP line) when the topic does not match TOPIC_RE_V1.

        With ``userasident`` true, an empty ident segment is replaced by the
        user segment; otherwise an empty ident becomes None and the user
        segment is ignored.
        """
        match = TOPIC_RE_V1.match(topic)
        if match is None:
            logging.info("SKIP topic V1: {}".format(topic))
            return None

        g = match.groupdict()
        m = self._topic_common(g)

        if userasident:
            if g["ident"] == "":
                m["ident"] = None if g["user"] == "" else g["user"]
            else:
                m["ident"] = g["ident"]
        else:
            # TODO manage user to migrate to correct user management and put user as constant station data
            m["ident"] = None if g["ident"] == "" else g["ident"]

        return m

    def parse_payload(self, payload):
        """Decode the UTF-8 JSON payload (bytes) and return the parsed object."""
        return json.loads(payload.decode("utf-8"))

    def parse_message(self, topic, payload, version="v0", userasident=True):
        """Combine topic and payload into a single message dict.

        Returns None when the topic is not recognised.  Otherwise returns the
        parsed topic dict extended with:

        - ``value``: payload key "v" (KeyError if absent);
        - ``datetime``: when both level and trange contain at least one
          non-None component, payload key "t" parsed as
          ``%Y-%m-%dT%H:%M:%S`` or, if "t" is absent, the current time;
          otherwise None;
        - ``attributes``: payload key "a", or an empty dict.
        """
        if version == "v0":
            t = self.parse_topic_v0(topic)
        else:
            t = self.parse_topic_v1(topic, userasident)

        if t is None:
            return None

        m = self.parse_payload(payload)
        msg = t.copy()
        msg["value"] = m["v"]

        has_level = t["level"] != (None, None, None, None)
        has_trange = t["trange"] != (None, None, None)
        if has_level and has_trange:
            if "t" in m:
                msg["datetime"] = datetime.strptime(m["t"], "%Y-%m-%dT%H:%M:%S")
            else:
                # NOTE(review): naive local time; confirm whether UTC is expected here
                msg["datetime"] = datetime.now()
        else:
            msg["datetime"] = None

        msg["attributes"] = m.get("a", {})

        return msg

    def on_connect(self, client, userdata, flags, rc):
        """Subscribe to every configured topic once the broker accepted us."""
        if rc != 0:
            # a non-zero result code means the broker refused the connection:
            # subscribing on it cannot succeed
            logging.error("MQTT connection refused, result code %s", rc)
            return

        qos = userdata.get("qos", 0)
        for topic in userdata["topics"]:
            client.subscribe(topic, qos=qos)

    def on_message(self, client, userdata, message):
        """Convert one MQTT message to BUFR and put it on the send queue.

        Unrecognised topics and networks listed in ``skipnetworks`` are
        dropped silently (apart from the parser's SKIP log line).  Any error
        while parsing or encoding is logged and the message is dropped; the
        callback itself never raises for those.
        """
        send_queue = userdata.get("send_queue")
        if send_queue is None:
            # without a destination there is nothing useful to do
            logging.error("No send_queue configured: dropping message on topic %s", message.topic)
            return

        version = userdata.get("version", "v0")
        userasident = userdata.get("userasident", True)
        skipnetworks = userdata.get("skipnetworks", [])

        logging.info('Message: topic %s payload %s', message.topic, message.payload)

        try:
            m = self.parse_message(message.topic, message.payload, version=version, userasident=userasident)
            if m is None:
                return

            if m["rep_memo"] in skipnetworks:
                return

            totalbody = self._encode_bufr(m)

            try:
                send_queue.put_nowait(Message(totalbody))
            except queue.Full:
                logging.error("Send queue full: dropping message on topic %s", message.topic)

        except Exception:
            logging.exception("Error processing message on topic %s", message.topic)


class mydaemon(daemon.Daemon):
    """Daemon subclass that adds the mqtt2amqpd command line options."""

    def optionparser(self):
        """Return the parent's option parser extended with the bridge options.

        Added options: -d/--datalevel, -s/--stationtype, -v/--rmapversion and
        -u/--nouserasident (stores False in ``userasident``).
        """
        parser = super().optionparser()

        # (flags, keyword arguments) in the order they must be registered
        extra_options = (
            (
                ("-d", "--datalevel"),
                {
                    "dest": "datalevel",
                    "help": "sample or report: define the istance to run: select topic, dns,logfile, errorfile and lockfile (default %default)",
                    "default": "sample",
                },
            ),
            (
                ("-s", "--stationtype"),
                {
                    "dest": "stationtype",
                    "help": "fixed or mobile: define the istance to run: select topic, dns,logfile, errorfile and lockfile (default %default)",
                    "default": "fixed",
                },
            ),
            (
                ("-v", "--rmapversion"),
                {
                    "dest": "rmapversion",
                    "help": "RMAP MQTT version (v0/v1) (default %default)",
                    "default": "v0",
                },
            ),
            (
                ("-u", "--nouserasident"),
                {
                    "dest": "userasident",
                    "action": "store_false",
                    "help": "do not use user as ident for fixed station too (default use ident)",
                    "default": True,
                },
            ),
        )
        for flags, settings in extra_options:
            parser.add_option(*flags, **settings)

        return parser

# Module-level daemon instance used by the script entry point.
# Log file, error file, pid/lock file and the user/group are all taken from
# rmap.settings; stdin is /dev/null.
mqtt2amqpd = mydaemon(
        stdin="/dev/null",
        stdout=rmap.settings.logfilemqtt2amqpd,
        stderr=rmap.settings.errfilemqtt2amqpd,
        pidfile=rmap.settings.lockfilemqtt2amqpd,
        user=rmap.settings.usermqtt2amqpd,
        group=rmap.settings.groupmqtt2amqpd
)

# catch signal to terminate the process
class GracefulKiller:
    """Record a termination request so the main loop can stop cleanly.

    Creating an instance installs ``exit_gracefully`` as the handler for
    SIGINT and SIGTERM.  ``kill_now`` is False until one of the methods
    below is called, after which it is True on that instance.
    """

    # class-level default; a request sets the attribute on the instance only
    kill_now = False

    def __init__(self):
        for signum in (signal.SIGINT, signal.SIGTERM):
            signal.signal(signum, self.exit_gracefully)

    def _request_stop(self):
        """Mark this instance as asked to terminate."""
        self.kill_now = True

    def exit_gracefully(self, signum, frame):
        """Signal handler: flag the termination request."""
        self._request_stop()

    def keyboard_interrupt(self):
        """Flag a termination caused by a keyboard interrupt."""
        self._request_stop()

    def terminate(self):
        """Flag a termination requested by the program itself."""
        self._request_stop()

def main(self):
    """Run the MQTT to AMQP bridge until termination is requested.

    ``self`` is the daemon object (the script calls ``main(mqtt2amqpd)``).
    Only ``self.options`` is read: ``stdout`` (path of the rotating log
    file), ``datalevel``, ``stationtype``, ``rmapversion`` and
    ``userasident``.

    The function starts the AMQP worker, connects the MQTT client and then
    polls every 3 seconds until a termination signal arrives or the AMQP
    worker is no longer alive; both are then shut down.

    Returns False in every case (invalid options as well as normal
    shutdown).  An exception raised by the initial MQTT connect is
    re-raised after the already started workers have been stopped.
    """
    import sys
    from logging.handlers import RotatingFileHandler

    # arm the signal handler
    killer = GracefulKiller()

    # configure the logger: rotating file attached to the root logger
    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
    handler = RotatingFileHandler(self.options.stdout, maxBytes=5000000, backupCount=10)
    handler.setFormatter(formatter)

    root_logger = logging.getLogger()
    root_logger.addHandler(handler)
    root_logger.setLevel(logging.INFO)

    logging.info('Starting up mqtt2amqpd')

    options = self.options
    topicdict = _build_topicdict()

    if options.datalevel not in topicdict:
        logging.error('Invalid datalevel')
        sys.stdout.write("Invalid datalevel\n")
        return False

    if options.stationtype not in topicdict[options.datalevel]:
        logging.error('Invalid stationtype')
        sys.stdout.write("Invalid stationtype\n")
        return False

    if options.rmapversion not in topicdict[options.datalevel][options.stationtype]:
        logging.error('Invalid rmapversion')
        sys.stdout.write("Invalid rmapversion\n")
        return False

    topic = topicdict[options.datalevel][options.stationtype][options.rmapversion]
    logging.info('Topic: %s', topic)

    # in v1 the "fixed" instance subscribes to all networks, so the special
    # "mobile" network has to be rejected explicitly
    if options.rmapversion == "v1" and options.stationtype == "fixed":
        skipnetworks = ["mobile"]
    else:
        skipnetworks = []
    for skipnetwork in skipnetworks:
        logging.info('Skip network: %s', skipnetwork)

    mqtthost = rmap.settings.mqtthost
    mqttuser = rmap.settings.mqttuser
    mqttpassword = rmap.settings.mqttpassword
    amqpuser = rmap.settings.amqpuser
    amqppassword = rmap.settings.amqppassword
    amqphost = rmap.settings.amqphost
    exchange = "rmap.mqtt.bufr." + options.datalevel + "_" + options.stationtype
    send_queue = queue.Queue(maxsize=1000000)

    amqp = amqpConsumerProducer(host=amqphost, exchange=exchange, user=amqpuser, password=amqppassword, send_queue=send_queue)
    amqp.start()

    # The client id must be unique on the broker and stable across restarts
    # (clean_session=False).  An unset user contributes an empty segment
    # instead of breaking the concatenation.
    client_id = "/".join([
        socket.gethostname(),
        mqttuser or "",
        options.datalevel,
        options.stationtype,
        options.rmapversion,
    ])

    mqttclient = MyMQTTClass(
        userdata={
            "topics": [topic],
            "send_queue": send_queue,
            "qos": 1,  # qos=1 required by clean_session=False
            "version": options.rmapversion,
            "userasident": options.userasident,
            "skipnetworks": skipnetworks,
        },
        client_id=client_id,
        clean_session=False,
    )

    if mqttuser:
        mqttclient.username_pw_set(mqttuser, mqttpassword)

    mqttclient.enable_logger()
    mqttclient.loop_start()
    try:
        mqttclient.connect(host=mqtthost)
    except Exception:
        # do not leave the AMQP worker and the MQTT loop thread running
        # behind a failed start-up
        logging.error("Cannot connect to MQTT broker %s", mqtthost)
        _shutdown_bridge(mqttclient, amqp)
        raise

    # infinite loop: all the work happens in the MQTT and AMQP workers
    while True:
        try:
            time.sleep(3)  # do nothing
        except KeyboardInterrupt:
            # terminate on keyboard interrupt
            sys.stdout.write("keyboard interrupt\n")
            logging.info("keyboard interrupt\n")
            killer.keyboard_interrupt()
        except Exception as exception:
            logging.error('Exception occured: ' + str(exception))
            logging.error(traceback.format_exc())
            killer.terminate()

        if not amqp.is_alive():
            logging.info("amqp aborted\n")
            killer.terminate()

        # check if we have to terminate together with other exceptions
        if killer.kill_now:
            logging.info("killed by signal\n")
            _shutdown_bridge(mqttclient, amqp)
            return False


def _build_topicdict():
    """Return the subscription topics as topicdict[datalevel][stationtype][version]."""
    roots = {
        "sample": rmap.settings.topicsample,
        "report": rmap.settings.topicreport,
    }
    topicdict = {}
    for datalevel, root in roots.items():
        topicdict[datalevel] = {
            "fixed": {
                "v0": "{}/+/+/{}/#".format(root, "fixed"),
                # v1 fixed listens on every network ("+")
                "v1": "1/{}/+/+/+/{}/#".format(root, "+"),
            },
            "mobile": {
                "v0": "{}/+/+/{}/#".format(root, "mobile"),
                "v1": "1/{}/+/+/+/{}/#".format(root, "mobile"),
            },
        }
    return topicdict


def _shutdown_bridge(mqttclient, amqp):
    """Disconnect the MQTT client, stop its loop thread and stop the AMQP worker."""
    mqttclient.disconnect()
    mqttclient.loop_stop()
    logging.debug("MQTT connection closed")

    amqp.terminate()
    amqp.join()
    logging.debug("AMQP terminated")
    

# Script entry point: runs only when the file is executed, not when imported.
if __name__ == '__main__':

    import sys, os
    
    # record the directory the script was launched from on the daemon object
    mqtt2amqpd.cwd=os.getcwd()

    # NOTE(review): service() presumably handles the start/stop/... command
    # line and daemonizes; a true result means "run the payload now" --
    # confirm against rmap.daemon
    if mqtt2amqpd.service():

        sys.stdout.write("Daemon started with pid %d\n" % os.getpid())

        # blocks until the bridge is asked to terminate
        main(mqtt2amqpd)  # (this code was run as script)
            
        # wait for every entry the daemon object lists in procs
        for proc in mqtt2amqpd.procs:
            proc.wait()

        sys.stdout.write("Daemon stoppped\n")
        sys.exit(0)
