#!/usr/bin/env python3

# Module-level debug switch.
# NOTE(review): nothing in this file reads it -- confirm whether importers use
# it or whether it can be removed.
debug = False

import socket, time, logging
import xml.etree.ElementTree as ET

# Dry-run switch: when True, requestStatus, sendConfig and startstopStreaming
# return early without contacting the Medusa server.
DONTUSEMEDUSA = False
# When True, Medusa.createConfig also writes the per-zoom VLBI output settings
# (ZOOM1_*/ZOOM2_* keys) into each stream's <vlbi_processing_parameters>.
SETZOOMVALUES = True

class MedusaConnectionError(Exception):
    """Signal that no TCP connection to the Medusa server could be opened.

    Raised when the server refuses the connection, and (in requestStatus) also
    when the server's host name cannot be resolved.
    """

class Medusa:
    """Control Medusa from Python over its TCP/XML interface.

    Status requests go to MEDUSASTATUSPORT; configure/start/stop commands go
    to MEDUSACOMMANDPORT. Every exchange opens a fresh TCP connection, sends
    one XML document terminated by CRLF, sleeps for a fixed time, reads a
    single reply and closes the connection.
    """

    MEDUSASTATUSPORT = 21200
    MEDUSACOMMANDPORT = 21100
    MEDUSASERVER = 'medusa-srv0.atnf.csiro.au'
    # For local testing: MEDUSASERVER = 'localhost'

    # Seconds allowed for each blocking socket operation (connect/send/recv).
    SOCKETTIMEOUT = 10
    # Maximum number of bytes read from a single Medusa reply.
    RECVSIZE = 32000
    # Destination IP written into every VLBI output section.  NEED TO SET
    VLBIDESTIP = "10.17.10.1"

    def __init__(self):
        # Parsed reply of the last successful requestStatus(), or None.
        self.statusXML = None
        self.log = logging.getLogger(__name__)
        # Last streaming state requested through startstopStreaming().
        self.streaming = True

    def __del__(self):
        print('*****Deleting Medusa object*****')

    # ------------------------------------------------------------------
    # Low-level socket helpers
    # ------------------------------------------------------------------

    def _connect(self, port):
        """Open a TCP connection to MEDUSASERVER:port and return the socket.

        The socket is closed before any error from connect() is re-raised, so
        a failed connection attempt never leaks a file descriptor. Errors
        (socket.gaierror, ConnectionRefusedError, socket.timeout, ...) are
        propagated unchanged; callers decide how to translate them.
        """
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.settimeout(self.SOCKETTIMEOUT)
            sock.connect((self.MEDUSASERVER, port))
        except BaseException:
            sock.close()
            raise
        return sock

    def _sendAndReceive(self, sock, payload, wait):
        """Send payload in full on sock, sleep wait seconds, return one recv().

        The caller owns (and closes) the socket.
        """
        sock.sendall(payload)  # send() may write only part of the buffer
        # Fixed grace period before reading; presumably gives Medusa time to
        # build its reply.
        time.sleep(wait)
        self.log.debug("Waiting for response")
        # NOTE(review): a single recv() assumes the whole reply (<= RECVSIZE
        # bytes) is available at once; a reply split across reads would fail
        # to parse.
        return sock.recv(self.RECVSIZE)

    def _checkReply(self, reply):
        """Log a <tcs_response> reply and return its text.

        Medusa answers e.g. b'<tcs_response>OK</tcs_response>' or
        b'<tcs_response>FAIL: ...</tcs_response>'; anything other than 'OK'
        is logged as an error.
        """
        self.log.debug(ET.tostring(reply).decode())
        status = reply.text
        self.log.debug('tcs_response={}'.format(status))
        if status != 'OK':
            self.log.error('Medusa returned status: {}'.format(status))
        return status

    # ------------------------------------------------------------------
    # Status
    # ------------------------------------------------------------------

    def requestStatus(self):
        """Request Medusa's current state and cache the parsed XML reply.

        Returns:
            True once the reply has been stored in self.statusXML; an empty
            (falsy) tuple when DONTUSEMEDUSA is set and nothing is sent.

        Raises:
            MedusaConnectionError: the server name cannot be resolved or the
                connection is refused.
            xml.etree.ElementTree.ParseError: the reply is not valid XML.
        """
        if DONTUSEMEDUSA:
            self.log.warning("WARNING NOT CONTROLLING MEDUSA")
            return ()

        # Form XML with the state request
        command_root = ET.Element("tcs_state_request")
        ET.SubElement(command_root, "requestor").text = "search page"
        ET.SubElement(command_root, "type").text = "state"
        command = ET.tostring(command_root, encoding="ISO-8859-1") + b'\r\n'

        try:
            sock = self._connect(self.MEDUSASTATUSPORT)
        except socket.gaierror as e:
            raise MedusaConnectionError("sock.connect failed: Unknown address {}".format(self.MEDUSASERVER)) from e
        except ConnectionRefusedError as e:
            raise MedusaConnectionError("sock.connect failed to connect to {}".format(self.MEDUSASERVER)) from e

        with sock:
            resp = self._sendAndReceive(sock, command, 1)

        self.statusXML = ET.XML(resp)
        return True

    def getStatus(self):
        """Return the 'state' attribute of <beam> from the last status reply.

        Returns None if no status has been fetched yet, or if the reply has
        no <beam> element or no 'state' attribute.
        """
        if self.statusXML is None:
            return None
        beam = self.statusXML.find("beam")
        if beam is None:
            return None
        return beam.get('state')

    def nStream(self):
        """Return the 'nstream' attribute of <instrument_configuration> as int.

        Requires a prior successful requestStatus(); with statusXML still None
        this raises AttributeError.
        """
        instConf = self.statusXML.find("instrument_configuration")
        return int(instConf.attrib['nstream'])

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    @staticmethod
    def _createZoom(zoom_proc_param, zoomID, active, bandwidth, freq, mode):
        """Append the ZOOM{n}_ACTIVE/BW/FREQUENCY/MODE elements for one zoom band.

        bandwidth and freq are written rounded to whole MHz.
        """
        ET.SubElement(zoom_proc_param, "active{:d}".format(zoomID), key="ZOOM{:d}_ACTIVE".format(zoomID)).text = str(active)
        ET.SubElement(zoom_proc_param, "bandwidth{:d}".format(zoomID), key="ZOOM{:d}_BW".format(zoomID), units="MHz").text = "{:.0f}".format(bandwidth)
        ET.SubElement(zoom_proc_param, "frequency{:d}".format(zoomID), key="ZOOM{:d}_FREQUENCY".format(zoomID), units="MHz").text = "{:.0f}".format(freq)
        ET.SubElement(zoom_proc_param, "mode{:d}".format(zoomID), key="ZOOM{:d}_MODE".format(zoomID)).text = str(mode)

    def _addStream(self, root, streamID, sub, bits, port, threadID):
        """Append the <stream{streamID}> section for one subband.

        sub is the per-stream sequence described in createConfig. threadID is
        the next free VDIF thread id; the next free id after this stream's
        allocations is returned.
        """
        stream = ET.SubElement(root, "stream{}".format(streamID))

        custom_param = ET.SubElement(stream, "custom_parameters")
        ET.SubElement(custom_param, "adaptive_filter", key="ADAPTIVE_FILTER").text = "0"
        ET.SubElement(custom_param, "adaptive_filter_epsilon", key="ADAPTIVE_FILTER_EPSILON").text = "0.1"
        ET.SubElement(custom_param, "adaptive_filter_nchan", key="ADAPTIVE_FILTER_NCHAN").text = "128"
        ET.SubElement(custom_param, "adaptive_filter_nsamp", key="ADAPTIVE_FILTER_NSAMP").text = "1024"
        ET.SubElement(custom_param, "schedule_block_id", key="SCHED_BLOCK_ID").text = "0"
        ET.SubElement(custom_param, "scan_id", key="SCAN_ID").text = "0"
        ET.SubElement(custom_param, "raw_baseband", key="RECORD_RAW_BASEBAND").text = "0"

        processing_modes = ET.SubElement(stream, "processing_modes")
        ET.SubElement(processing_modes, "fold", key="PERFORM_FOLD").text = "0"
        ET.SubElement(processing_modes, "search", key="PERFORM_SEARCH").text = "0"
        ET.SubElement(processing_modes, "continuum", key="PERFORM_CONTINUUM").text = "0"
        ET.SubElement(processing_modes, "spectral_line", key="PERFORM_SPECTRAL_LINE").text = "0"
        # VLBI processing is enabled whenever zoom 1 has a frequency.
        ET.SubElement(processing_modes, "vlbi", key="PERFORM_VLBI").text = "1" if sub[2] is not None else "0"
        ET.SubElement(processing_modes, "baseband", key="PERFORM_BASEBAND").text = "0"

        zoom_proc_param = ET.SubElement(stream, "zoom_processing_parameters")
        # Don't create a zoom for 128 MHz mode. "128 MHz mode" is a zoom-1
        # *bandwidth* (index 3) of 128, the same test used for the VDIF
        # thread ids below.
        if sub[2] is not None and sub[3] != 128.0:
            self._createZoom(zoom_proc_param, 1, 1, sub[3], sub[2], "vlbi")
        else:
            self._createZoom(zoom_proc_param, 1, 0, 4, 832, "baseband")

        if sub[4] is not None:
            self._createZoom(zoom_proc_param, 2, 1, sub[5], sub[4], "vlbi")
        else:
            self._createZoom(zoom_proc_param, 2, 0, 4, 832, "baseband")

        vlbi_proc_param = ET.SubElement(stream, "vlbi_processing_parameters")
        ET.SubElement(vlbi_proc_param, "output_nbit", key="VLBI_OUTNBIT").text = str(bits)
        ET.SubElement(vlbi_proc_param, "output_encoding", key="VLBI_ENCODING").text = "offset_binary"
        ET.SubElement(vlbi_proc_param, "dest_ip", key="VLBI_DEST_IP").text = self.VLBIDESTIP
        ET.SubElement(vlbi_proc_param, "dest_port", key="VLBI_DEST_PORT").text = str(port)
        # The main stream only gets its own thread id in 128 MHz mode.
        if sub[2] is not None and sub[3] == 128.0:
            mainThread = str(threadID)
            threadID += 1
        else:
            mainThread = "0"
        ET.SubElement(vlbi_proc_param, "vdif_thread_id", key="VLBI_VDIF_THREAD_ID").text = mainThread
        ET.SubElement(vlbi_proc_param, "dest_protocol", key="VLBI_DEST_PROTOCOL").text = "udp"  # NEED TO SET
        ET.SubElement(vlbi_proc_param, "dest_sequence_number", key="VLBI_DEST_SEQ_NO").text = "1"  # NEED TO SET
        ET.SubElement(vlbi_proc_param, "polarisation", key="VLBI_POLARISATION").text = "linear"

        if SETZOOMVALUES:
            for z in (1, 2):
                freq, bw = z * 2, z * 2 + 1  # indices of this zoom's frequency / bandwidth
                ET.SubElement(vlbi_proc_param, "zoom{}_output_nbit".format(z), key="ZOOM{}_VLBI_OUTNBIT".format(z)).text = str(bits)
                ET.SubElement(vlbi_proc_param, "zoom{}_output_encoding".format(z), key="ZOOM{}_VLBI_ENCODING".format(z)).text = "offset_binary"
                ET.SubElement(vlbi_proc_param, "zoom{}_dest_ip".format(z), key="ZOOM{}_VLBI_DEST_IP".format(z)).text = self.VLBIDESTIP
                ET.SubElement(vlbi_proc_param, "zoom{}_dest_port".format(z), key="ZOOM{}_VLBI_DEST_PORT".format(z)).text = str(port)
                if sub[freq] is not None and sub[bw] != 128.0:
                    zoomThread = str(threadID)
                    threadID += 1
                else:
                    zoomThread = "0"
                ET.SubElement(vlbi_proc_param, "zoom{}_vdif_thread_id".format(z), key="ZOOM{}_VDIF_THREAD_ID".format(z)).text = zoomThread
                ET.SubElement(vlbi_proc_param, "zoom{}_dest_protocol".format(z), key="ZOOM{}_VLBI_DEST_PROTOCOL".format(z)).text = "udp"  # NEED TO SET
                ET.SubElement(vlbi_proc_param, "zoom{}_dest_sequence_number".format(z), key="ZOOM{}_VLBI_DEST_SEQ_NO".format(z)).text = "1"  # NEED TO SET
                ET.SubElement(vlbi_proc_param, "zoom{}_polarisation".format(z), key="ZOOM{}_VLBI_POLARISATION".format(z)).text = "linear"

        baseband_proc_param = ET.SubElement(stream, "baseband_processing_parameters")
        ET.SubElement(baseband_proc_param, "output_nbit", key="BASEBAND_OUTNBIT").text = "4"
        ET.SubElement(baseband_proc_param, "output_encoding", key="BASEBAND_ENCODING").text = "twos_complement"

        return threadID

    def createConfig(self, subbands, exper, cal, bits, port):
        """Build the Medusa "configure" <obs_cmd> XML for a VLBI observation.

        Args:
            subbands: mapping of stream id -> per-stream sequence; keys are
                emitted in numeric order (sorted with key=int). Indices read:
                [2] zoom-1 frequency in MHz (or None), [3] zoom-1 bandwidth
                in MHz (128.0 selects "128 MHz mode"), [4] zoom-2 frequency
                in MHz (or None), [5] zoom-2 bandwidth in MHz.
            exper: project id, written as PID.
            cal: dict with keys 'calOn', 'frequency', 'phase', 'dutyCycle'
                (percent, written as a fraction), 'calEpoch', 'tsysAvgTime'
                and 'tsysFreqRes'.
            bits: value written to VLBI_OUTNBIT (and ZOOMn_VLBI_OUTNBIT).
            port: value written to VLBI_DEST_PORT (and ZOOMn_VLBI_DEST_PORT).

        Returns:
            The root <obs_cmd> Element.
        """
        subs = sorted(subbands, key=int)  # numeric order, so "10" follows "9"
        nStream = len(subs)

        root = ET.Element("obs_cmd")
        ET.SubElement(root, "command").text = "configure"

        beam_config = ET.SubElement(root, "beam_configuration")
        ET.SubElement(beam_config, "nbeam", key="NBEAM").text = "1"
        ET.SubElement(beam_config, "beam_state_0", key="BEAM_STATE_0", name="1").text = "1"

        stream_config = ET.SubElement(root, "stream_configuration")
        ET.SubElement(stream_config, "nstream", key="NSTREAM").text = str(nStream)
        for _ in range(nStream):
            ET.SubElement(stream_config, "active", key="STREAM_ACTIVE").text = "1"

        source = ET.SubElement(root, "source_parameters")
        ET.SubElement(source, "name", epoch="J2000", key="SOURCE").text = "VLBI"
        ET.SubElement(source, "ra", key="RA", units="hh:mm:ss").text = "00:00:00"
        ET.SubElement(source, "dec", key="DEC", units="dd:mm:ss").text = "00:00:00"

        obs_parm = ET.SubElement(root, "observation_parameters")
        ET.SubElement(obs_parm, "observer", key="OBSERVER").text = "VLBI"
        ET.SubElement(obs_parm, "project_id", key="PID").text = exper
        ET.SubElement(obs_parm, "tobs", key="TOBS").text = "60"
        ET.SubElement(obs_parm, "utc_start", key="UTC_START").text = "None"
        ET.SubElement(obs_parm, "utc_stop", key="UTC_STOP").text = "None"

        cal_parm = ET.SubElement(root, "calibration_parameters")
        ET.SubElement(cal_parm, "signal", key="CAL_SIGNAL").text = "1" if cal['calOn'] else "0"
        ET.SubElement(cal_parm, "freq", key="CAL_FREQ", units="Hertz").text = str(cal['frequency'])
        ET.SubElement(cal_parm, "phase", key="CAL_PHASE").text = str(cal['phase'])
        ET.SubElement(cal_parm, "duty_cycle", key="CAL_DUTY_CYCLE").text = "{:.2f}".format(float(cal['dutyCycle'])/100.0)
        ET.SubElement(cal_parm, "epoch", key="CAL_EPOCH", units="YYYY-DD-MM-HH:MM:SS+0").text = str(cal['calEpoch'])
        ET.SubElement(cal_parm, "tsys_avg_time", key="TSYS_AVG_TIME", units="seconds").text = str(cal['tsysAvgTime'])
        ET.SubElement(cal_parm, "tsys_freq_resolution", key="TSYS_FREQ_RES", units="MHz").text = str(cal['tsysFreqRes'])

        # VDIF thread ids are allocated sequentially across all streams.
        threadID = 0
        for s in subs:
            threadID = self._addStream(root, s, subbands[s], bits, port, threadID)

        return root

    def sendConfig(self, configxml):
        """Send a configuration built by createConfig() to Medusa.

        A copy is also written to "medusa-configuration.xml" in the current
        directory. A reply other than 'OK' is logged as an error.

        Connection errors (socket.gaierror, ConnectionRefusedError, timeouts)
        propagate unchanged.
        """
        if DONTUSEMEDUSA:
            self.log.warning("WARNING NOT CONTROLLING MEDUSA")
            return ()

        # NOTE(review): unlike requestStatus/startstopStreaming, connection
        # errors here are not wrapped; kept as-is for caller compatibility.
        with self._connect(self.MEDUSACOMMANDPORT) as sock:
            config = ET.tostring(configxml, encoding="ISO-8859-1") + b'\r\n'

            tree = ET.ElementTree(configxml)
            tree.write("medusa-configuration.xml", encoding="ISO-8859-1", xml_declaration=True)

            self.log.info("Sending config to Medusa")
            resp = self._sendAndReceive(sock, config, 2)

        self._checkReply(ET.XML(resp))

    # ------------------------------------------------------------------
    # Streaming control
    # ------------------------------------------------------------------

    def startstopStreaming(self, start):
        """Send a start (start=True) or stop command to Medusa.

        Uses the stream count from the last requestStatus(). When
        DONTUSEMEDUSA is set nothing is sent, no status is needed and
        self.streaming is left unchanged.

        Raises:
            RuntimeError: the server name cannot be resolved.
            MedusaConnectionError: the connection is refused.
        """
        # Checked first: in dry-run mode requestStatus() never fills
        # statusXML, so building the control XML would fail.
        if DONTUSEMEDUSA:
            self.log.warning("WARNING NOT CONTROLLING MEDUSA")
            return ()

        control = ET.tostring(self.controlXML(start), encoding="ISO-8859-1") + b'\r\n'

        try:
            sock = self._connect(self.MEDUSACOMMANDPORT)
        except socket.gaierror as e:
            # RuntimeError kept for compatibility with existing callers.
            raise RuntimeError("sock.connect failed: Unknown address {}".format(self.MEDUSASERVER)) from e
        except ConnectionRefusedError as e:
            raise MedusaConnectionError("sock.connect failed to connect to {}".format(self.MEDUSASERVER)) from e

        with sock:
            resp = self._sendAndReceive(sock, control, 5)

        self._checkReply(ET.XML(resp))
        # TODO Need to stop experiment, raise exception or return code when
        # the reply is not OK; the streaming flag is updated regardless.
        self.streaming = bool(start)

    def startStreaming(self):
        """Ask Medusa to start streaming."""
        self.log.info("Start Medusa Streaming")
        self.startstopStreaming(True)

    def stopStreaming(self):
        """Ask Medusa to stop streaming."""
        self.log.info("Stop Medusa Streaming")
        self.startstopStreaming(False)

    def controlXML(self, start):
        """Build the start/stop <obs_cmd> XML; needs statusXML for nStream()."""
        nStream = self.nStream()
        root = ET.Element("obs_cmd")

        ET.SubElement(root, "command").text = "start" if start else "stop"

        beam_config = ET.SubElement(root, "beam_configuration")
        ET.SubElement(beam_config, "nbeam", key="NBEAM").text = "1"
        ET.SubElement(beam_config, "beam_state_0", key="BEAM_STATE_0", name="1").text = "1"

        stream_config = ET.SubElement(root, "stream_configuration")
        ET.SubElement(stream_config, "nstream", key="NSTREAM").text = str(nStream)
        for _ in range(nStream):
            ET.SubElement(stream_config, "active", key="STREAM_ACTIVE").text = "1"

        obs_parm = ET.SubElement(root, "observation_parameters")
        if start:
            ET.SubElement(obs_parm, "utc_start", key="UTC_START").text = "None"
        else:
            ET.SubElement(obs_parm, "utc_stop", key="UTC_STOP").text = "None"

        return root


if __name__ == "__main__":
    # Minimal smoke run: constructing a Medusa object performs no network
    # I/O, so this only checks that the module loads and the class builds.
    medusa = Medusa()
    print("Hello  World")
