#!/usr/bin/env python3

import xml.etree.ElementTree as ET
import socket
import argparse, sys, re, time, os


MEDUSASTATUSPORT = 21200
MEDUSACOMMANDPORT = 21100
MEDUSASERVER = 'medusa-srv0.atnf.csiro.au'

# Default backend calibration settings. Any of these keys may be overridden
# by "key = value" lines in the cal setup file read further down.
# Need to ensure these are matched to defaults in uwlCalSetup
defaultCalConfig = {
    'frequency': 0.2,
    'dutyCycle': 50,
    'phase': 0,
    'startEdge': 'Rising',
    'startBat': '0x122b5790caa8a3',
    'calEpoch': '2020-12-09-04:23:55.141475',
    'tsysAvgTime': 5,
    'tsysFreqRes': 1.0,
}

# Cal setup file: taken from $LBAUWB-CAL when set, otherwise the site default
# (the --cal option below overrides both).
calSetup = os.getenv('LBAUWB-CAL', '/home/pksobs/lba/lbauwb.cal')

parser = argparse.ArgumentParser()
parser.add_argument('-b', '--bits', '-bits', type=int, default=2,
                    help="Number of bits")
parser.add_argument('-p', '--port', '-port', type=int, default=10000,
                    help="UDP port number")
parser.add_argument('-d', '--debug', '-debug', action="store_true",
                    help="Debug mode - do not send config to Medusa")
parser.add_argument('-spip', '--spip',
                    help="Read SPIP XML from file, not socket. Enables debug mode")
parser.add_argument('--cal', '-cal',
                    help="UWB Backend Cal settings")
parser.add_argument('--nocal', '-nocal', action="store_true",
                    help="Disable Cal signal")
parser.add_argument('config',
                    help="Experiment XML")
args = parser.parse_args()

# From stack overflow
#     https://stackoverflow.com/questions/28813876/how-do-i-get-pythons-elementtree-to-pretty-print-to-an-xml-file

def _pretty_print(current, parent=None, index=-1, depth=0):
    for i, node in enumerate(current):
        _pretty_print(node, current, i, depth + 1)
    if parent is not None:
        if index == 0:
            parent.text = '\n' + ('  ' * depth)
        else:
            parent[index - 1].tail = '\n' + ('  ' * depth)
        if index == len(parent) - 1:
            current.tail = '\n' + ('  ' * (depth - 1))
def getElem(parent, elem):
    """Return the text of the first ``elem`` child (or path) under ``parent``.

    Exits the program with "Error: Could not find <elem>" if no match exists.
    The returned text is None when the element is present but empty.
    """
    found = parent.find(elem)
    if found is not None:
        return found.text
    sys.exit("Error: Could not find <{}>".format(elem))

if args.cal is not None:
    calSetup = args.cal

# The cal setup file is required: stop with a clear message rather than
# letting open() below fail with a traceback.
if not os.path.isfile(calSetup):
    print("Error: '{}' does not exist".format(calSetup))
    sys.exit(1)

# Splits "content # comment"; group(1) is the text before the first '#'.
commentRE = re.compile(r'^([^#]*)#(.*)$')

# Read "key = value" lines from the cal setup file, overriding entries in
# defaultCalConfig. '#' starts a comment and blank lines are skipped.
# Malformed lines and unknown keys are fatal.
with open(calSetup, 'r') as f:
    for line in f:
        m = commentRE.match(line)
        if m:
            line = m.group(1)
        # Skip blank (or comment-only) lines
        if not line.strip():
            continue
        ll = line.split('=')
        if len(ll) != 2:
            print("Error: Failed to parse '{}'".format(line.strip()))
            sys.exit(1)
        key = ll[0].strip()
        val = ll[1].strip()
        if key not in defaultCalConfig:
            print("Error: {} not in known calibration settings".format(key))
            sys.exit(1)
        print("Setting cal '{}' to '{}'".format(key, val))
        defaultCalConfig[key] = val  # This is a string now
debug = args.debug
if args.spip is not None:
    # Reading the SPIP state from a file forces debug mode, so no
    # configuration is sent to Medusa.
    debug = True

    # Read Medusa setup from local file (for debugging)
    tree = ET.parse(args.spip)
    spip_setup = tree.getroot()

else:
    # Form XML with the state request
    command_root = ET.Element("tcs_state_request")
    ET.SubElement(command_root, "requestor").text = "search page"
    ET.SubElement(command_root, "type").text = "state"
    command = ET.tostring(command_root, encoding="ISO-8859-1")
    command += b'\r\n'

    # Query the Medusa status port. The context manager closes the socket
    # even if connect/send/recv raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((MEDUSASERVER, MEDUSASTATUSPORT))
        # sendall() keeps writing until the whole request is sent; send()
        # may write only part of it.
        sock.sendall(command)
        # Give the server time to reply, then read the reply in one recv.
        time.sleep(1)
        resp = sock.recv(32000)

    if not resp:
        print("Error: No response from Medusa status server")
        sys.exit(1)

    spip_setup = ET.XML(resp)

    if debug:
        # Keep a pretty-printed copy of the state (e.g. for later use with --spip)
        _pretty_print(spip_setup)
        tree = ET.ElementTree(spip_setup)
        tree.write("spip.xml", encoding="ISO-8859-1", xml_declaration=True)
    
# Refuse to reconfigure Medusa unless it is Idle or Configured (the check is
# skipped in debug mode).
beam = spip_setup.find("beam")
if beam is None or 'state' not in beam.attrib:
    print("Error: Could not find beam state in Medusa status")
    sys.exit(1)
medusa_state = beam.attrib['state']

if not debug and medusa_state not in ('Idle', 'Configured'):
    print("Warning: Medusa system currently \"{}\". Will not try and configure system".format(medusa_state))
    sys.exit(1)

# Build the subband table: stream id -> [centre freq, bandwidth,
# zoom1 freq, zoom1 bw, zoom2 freq, zoom2 bw]. The zoom slots start as None
# and are filled in by findSubband().
instConf = spip_setup.find("instrument_configuration")
if instConf is None:
    print("Error: Could not find <instrument_configuration> in Medusa status")
    sys.exit(1)
# NOTE(review): nstream/ns are not currently cross-checked against each other.
nstream = int(instConf.attrib['nstream'])

subbands = {}

ns = 0
for stream in instConf.findall('stream'):
    freq = float(getElem(stream, 'centre_frequency'))
    # abs(): presumably the bandwidth can be negative (inverted band) - confirm
    bw = abs(float(getElem(stream, 'bandwidth')))
    subbands[stream.attrib['id']] = [freq, bw, None, None, None, None]
    ns += 1

# Read Experiment XML config file

tree = ET.parse(args.config)
root = tree.getroot()

exper = getElem(root, 'exper')
setups = root.findall('./mode/setup')

# Exactly one <mode>/<setup> is supported. With none, setups[0] below would
# fail with an IndexError.
if not setups:
    sys.exit("Error: No <mode>/<setup> found in {}".format(args.config))
if len(setups) > 1:
    sys.exit("Error: Only support single setup")

setup = setups[0]

config = getElem(setup, 'name')

def findSubband(freq, bandwidth, subband_table=None):
    """Assign a zoom band to the first subband that fully contains it.

    ``subband_table`` maps stream id -> [centre, bw, zoom1 freq, zoom1 bw,
    zoom2 freq, zoom2 bw]; when omitted the module-level ``subbands`` dict is
    used. The zoom range ``freq +/- bandwidth/2`` is matched against each
    subband's range ``centre +/- bw/2``. On a match, the zoom is stored in the
    subband's first free zoom slot (the list is modified in place).

    A warning is printed for each subband the zoom overlaps without being
    fully contained in it.

    Args:
        freq: zoom centre frequency (same units as the subband table).
        bandwidth: zoom bandwidth.
        subband_table: optional subband dict to use instead of ``subbands``.

    Returns:
        The matching subband id, or None if no subband encloses the zoom or
        the enclosing subband already holds two zooms.
    """
    if subband_table is None:
        subband_table = subbands

    freq0 = freq - bandwidth / 2
    freq1 = freq + bandwidth / 2

    for subband, band in subband_table.items():
        f = band[0]
        bw = band[1]
        f0 = f - bw / 2.0
        f1 = f + bw / 2.0

        # Does zoom sit within this subband
        if freq0 >= f0 and freq1 <= f1:
            if band[2] is None:
                band[2] = freq
                band[3] = bandwidth
            elif band[4] is None:
                band[4] = freq
                band[5] = bandwidth
            else:
                print("Error: Maximum 2 zooms per subband possible")
                return None
            return subband
        # Not enclosed, so any overlap means the zoom straddles an edge. This
        # includes zooms that begin or end exactly on an edge, or that span
        # the whole subband.
        if freq0 < f1 and freq1 > f0:
            print("Warning  Zoom {}:{} straddles Subband {} edge".format(freq, bandwidth, subband))
    return None

# Place every <zoom> of the setup into a Medusa subband. A zoom that cannot
# be placed is fatal.
for zoom in setup.findall('zoom'):
    freq = float(getElem(zoom, 'frequency'))
    bandwidth = float(getElem(zoom, 'bandwidth'))
    # polarisation and sideband are not used further, but reading them makes
    # getElem() exit if either element is missing.
    pol = getElem(zoom, 'polarisation')
    sideband = getElem(zoom, 'sideband')
    matchSubband = findSubband(freq, bandwidth)
    if matchSubband is not None:
        continue
    print("Error: Could not match zoom {}:{}".format(freq, bandwidth))
    sys.exit(1)

# Show the resulting subband -> zoom assignment
for sub_id, band in subbands.items():
    print(sub_id, ": ", band)

def SPIP_configure(subbands, exper):
    """Build the Medusa (SPIP) ``obs_cmd`` "configure" XML for a VLBI observation.

    One ``stream<id>`` section is emitted per subband, in numeric id order.
    Subbands with a zoom assigned get VLBI processing enabled and their zoom
    slots written in "vlbi" mode; unused zoom slots are written as inactive
    placeholders.

    Settings are taken from module-level state:

    - calibration settings from ``defaultCalConfig``;
    - the cal on/off switch, output bit count and destination UDP port from
      the parsed command-line ``args``.

    Args:
        subbands: dict of stream id (an int as a string) -> [centre freq,
            bandwidth, zoom1 freq, zoom1 bw, zoom2 freq, zoom2 bw], with the
            zoom entries None when unused.
        exper: experiment name, sent as the project id (PID).

    Returns:
        The root ``obs_cmd`` Element.
    """
    subs = sorted(subbands, key=int)  # Make sure in order
    nStream = len(subs)

    def createZoom(zoom_proc_param, zoomID, active, bandwidth, freq, mode):
        # Append the ACTIVE/BW/FREQUENCY/MODE elements for zoom slot zoomID.
        # NOTE(review): bandwidth and frequency are rounded to whole MHz by
        # "{:.0f}"; confirm Medusa cannot take fractional values.
        ET.SubElement(zoom_proc_param, "active{:d}".format(zoomID), key="ZOOM{:d}_ACTIVE".format(zoomID)).text = str(active)
        ET.SubElement(zoom_proc_param, "bandwidth{:d}".format(zoomID), key="ZOOM{:d}_BW".format(zoomID), units="MHz").text = "{:.0f}".format(bandwidth)
        ET.SubElement(zoom_proc_param, "frequency{:d}".format(zoomID), key="ZOOM{:d}_FREQUENCY".format(zoomID), units="MHz").text = "{:.0f}".format(freq)
        ET.SubElement(zoom_proc_param, "mode{:d}".format(zoomID), key="ZOOM{:d}_MODE".format(zoomID)).text = str(mode)

    root = ET.Element("obs_cmd")
    ET.SubElement(root, "command").text = "configure"

    beam_config = ET.SubElement(root, "beam_configuration")
    ET.SubElement(beam_config, "nbeam", key="NBEAM").text = "1"
    ET.SubElement(beam_config, "beam_state_0", key="BEAM_STATE_0", name="1").text = "1"

    stream_config = ET.SubElement(root, "stream_configuration")
    ET.SubElement(stream_config, "nstream", key="NSTREAM").text = str(nStream)
    for _ in range(nStream):
        ET.SubElement(stream_config, "active", key="STREAM_ACTIVE").text = "1"

    source = ET.SubElement(root, "source_parameters")
    ET.SubElement(source, "name", epoch="J2000", key="SOURCE").text = "VLBI"
    ET.SubElement(source, "ra", key="RA", units="hh:mm:ss").text = "00:00:00"
    ET.SubElement(source, "dec", key="DEC", units="dd:mm:ss").text = "00:00:00"

    obs_parm = ET.SubElement(root, "observation_parameters")
    ET.SubElement(obs_parm, "observer", key="OBSERVER").text = "VLBI"
    ET.SubElement(obs_parm, "project_id", key="PID").text = exper
    ET.SubElement(obs_parm, "tobs", key="TOBS").text = "60"
    ET.SubElement(obs_parm, "utc_start", key="UTC_START").text = "None"
    ET.SubElement(obs_parm, "utc_stop", key="UTC_STOP").text = "None"

    cal_parm = ET.SubElement(root, "calibration_parameters")
    ET.SubElement(cal_parm, "signal", key="CAL_SIGNAL").text = "0" if args.nocal else "1"
    ET.SubElement(cal_parm, "freq", key="CAL_FREQ", units="Hertz").text = str(defaultCalConfig['frequency'])
    ET.SubElement(cal_parm, "phase", key="CAL_PHASE").text = str(defaultCalConfig['phase'])
    # dutyCycle is configured in percent; Medusa is sent a fraction
    ET.SubElement(cal_parm, "duty_cycle", key="CAL_DUTY_CYCLE").text = "{:.2f}".format(float(defaultCalConfig['dutyCycle'])/100.0)
    # NOTE(review): units say YYYY-DD-MM but the default calEpoch reads as
    # YYYY-MM-DD; left as-is because the string is sent verbatim to Medusa.
    ET.SubElement(cal_parm, "epoch", key="CAL_EPOCH", units="YYYY-DD-MM-HH:MM:SS+0").text = str(defaultCalConfig['calEpoch'])
    ET.SubElement(cal_parm, "tsys_avg_time", key="TSYS_AVG_TIME", units="seconds").text = str(defaultCalConfig['tsysAvgTime'])
    ET.SubElement(cal_parm, "tsys_freq_resolution", key="TSYS_FREQ_RES", units="MHz").text = str(defaultCalConfig['tsysFreqRes'])

    for s in subs:
        band = subbands[s]
        # (freq, bw) for zoom slots 1 and 2; freq None means the slot is unused
        zoom_slots = ((band[2], band[3]), (band[4], band[5]))

        stream = ET.SubElement(root, "stream{}".format(s))

        custom_param = ET.SubElement(stream, "custom_parameters")
        ET.SubElement(custom_param, "adaptive_filter", key="ADAPTIVE_FILTER").text = "0"
        ET.SubElement(custom_param, "adaptive_filter_epsilon", key="ADAPTIVE_FILTER_EPSILON").text = "0.1"
        ET.SubElement(custom_param, "adaptive_filter_nchan", key="ADAPTIVE_FILTER_NCHAN").text = "128"
        ET.SubElement(custom_param, "adaptive_filter_nsamp", key="ADAPTIVE_FILTER_NSAMP").text = "1024"
        ET.SubElement(custom_param, "schedule_block_id", key="SCHED_BLOCK_ID").text = "0"
        ET.SubElement(custom_param, "scan_id", key="SCAN_ID").text = "0"
        ET.SubElement(custom_param, "raw_baseband", key="RECORD_RAW_BASEBAND").text = "0"

        processing_modes = ET.SubElement(stream, "processing_modes")
        ET.SubElement(processing_modes, "fold", key="PERFORM_FOLD").text = "0"
        ET.SubElement(processing_modes, "search", key="PERFORM_SEARCH").text = "0"
        ET.SubElement(processing_modes, "continuum", key="PERFORM_CONTINUUM").text = "0"
        ET.SubElement(processing_modes, "spectral_line", key="PERFORM_SPECTRAL_LINE").text = "0"
        # VLBI processing is enabled only when the first zoom slot is used
        ET.SubElement(processing_modes, "vlbi", key="PERFORM_VLBI").text = "1" if band[2] is not None else "0"
        ET.SubElement(processing_modes, "baseband", key="PERFORM_BASEBAND").text = "0"

        zoom_proc_param = ET.SubElement(stream, "zoom_processing_parameters")
        for zoomID, (zoom_freq, zoom_bw) in enumerate(zoom_slots, start=1):
            if zoom_freq is not None:
                createZoom(zoom_proc_param, zoomID, 1, zoom_bw, zoom_freq, "vlbi")
            else:
                # Inactive slot: fixed placeholder bandwidth/frequency
                createZoom(zoom_proc_param, zoomID, 0, 4, 832, "baseband")

        vlbi_proc_param = ET.SubElement(stream, "vlbi_processing_parameters")
        ET.SubElement(vlbi_proc_param, "output_nbit", key="VLBI_OUTNBIT").text = str(args.bits)
        ET.SubElement(vlbi_proc_param, "output_encoding", key="VLBI_ENCODING").text = "offset_binary"
        ET.SubElement(vlbi_proc_param, "dest_ip", key="VLBI_DEST_IP").text = "10.17.10.1"
        ET.SubElement(vlbi_proc_param, "dest_port", key="VLBI_DEST_PORT").text = str(args.port)
        ET.SubElement(vlbi_proc_param, "vdif_thread_id", key="VLBI_VDIF_THREAD_ID").text = "0"
        ET.SubElement(vlbi_proc_param, "dest_protocol", key="VLBI_DEST_PROTOCOL").text = "udp"
        ET.SubElement(vlbi_proc_param, "dest_sequence_number", key="VLBI_DEST_SEQ_NO").text = "1"

        baseband_proc_param = ET.SubElement(stream, "baseband_processing_parameters")
        ET.SubElement(baseband_proc_param, "output_nbit", key="BASEBAND_OUTNBIT").text = "4"
        ET.SubElement(baseband_proc_param, "output_encoding", key="BASEBAND_ENCODING").text = "twos_complement"

    return root

# Enable (or not) Cal signal
# NOTE: this step is currently switched off by the leading `False` in the
# condition, so uwlCalSetup is never run.
# NOTE(review): calSetup is interpolated unquoted into a shell command; a path
# containing spaces would break if this is re-enabled.

if False and not debug:
    if args.nocal:
        cal_command = 'uwlCalSetup --no-med calOff'
    else:
        cal_command = 'uwlCalSetup --no-med -f {} calon'.format(calSetup)
    ret = os.system(cal_command)

    if ret != 0:
        print("uwlCalSetup failed to run ({})".format(os.WEXITSTATUS(ret)))
        sys.exit(1)

config_xml = SPIP_configure(subbands, exper)

if debug:
    # Debug mode: write the configuration locally instead of sending it
    _pretty_print(config_xml)
    tree = ET.ElementTree(config_xml)
    tree.write("configuration.xml", encoding="ISO-8859-1", xml_declaration=True)
else:
    # Send configuration XML to server
    print("Sending config XML to Medusa")

    config = ET.tostring(config_xml, encoding="ISO-8859-1")
    config += b'\r\n'

    # The context manager closes the socket even if connect/send/recv raises
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((MEDUSASERVER, MEDUSACOMMANDPORT))
        # sendall() keeps writing until the whole document is sent
        sock.sendall(config)
        time.sleep(2)

        print("Waiting for response")
        resp = sock.recv(32000)

    if not resp:
        print("Error: No response from Medusa")
        sys.exit(1)

    spip_reply = ET.XML(resp)
    _pretty_print(spip_reply)
    # Decode to str so the indented reply prints as readable XML rather than
    # as a bytes repr with escaped newlines.
    print(ET.tostring(spip_reply, encoding="unicode"))



    


