#!/usr/bin/env python3

import xml.etree.ElementTree as ET
import argparse, sys, re

# Command-line interface.
#   -b/--bits : written to the VLBI_OUTNBIT field of every stream (default 2)
#   -p/--port : written to the VLBI_DEST_PORT field of every stream (default 10000)
#   config    : experiment XML file (parsed with ElementTree below)
#   spip      : SPIP config file (scanned for SUBBAND_CONFIG_<n> lines below)
parser = argparse.ArgumentParser()
parser.add_argument('-b', '--bits', '-bits', help="Number of bits", type=int, default=2)
parser.add_argument('-p', '--port', '-port', help="UDP port number", type=int, default=10000)
# Disabled options, kept for reference:
#parser.add_argument('-N', '--nchan', '-nchan', help="Number of voltage channels (IFs)", type=int, default=1)
#parser.add_argument('-f', '--nfft', '-nfft', help="Number of FFT to average", type=int, default=100)
#parser.add_argument('-c', '--complex', '-complex', help="Assume complex voltages", action="store_true")
parser.add_argument('config', help="Experiment XML")
# Help text previously read "SPIP confg" (typo shown to the user in --help).
parser.add_argument('spip', help="SPIP config")
args = parser.parse_args()

# Read spip config file to get the subband sky frequencies and bandwidth

# Pretty-printing helper adapted from Stack Overflow:
#     https://stackoverflow.com/questions/28813876/how-do-i-get-pythons-elementtree-to-pretty-print-to-an-xml-file

def _pretty_print(current, parent=None, index=-1, depth=0):
    """Indent an ElementTree element tree in place.

    Rewrites the ``text``/``tail`` whitespace of the elements below
    ``current`` so that each nesting level is indented by two spaces when
    the tree is serialised.  Callers pass only the root element; ``parent``,
    ``index`` (position of ``current`` within ``parent``) and ``depth`` are
    supplied by the recursive calls.

    The root's own ``tail`` is left untouched, so no trailing newline is
    added after the closing root tag.
    """
    # Children first: their processing sets this element's text and the
    # tails of its children.
    for position, child in enumerate(current):
        _pretty_print(child, current, position, depth + 1)

    if parent is None:
        return

    lead = '\n' + ('  ' * depth)
    if index == 0:
        # First child: the whitespace before it lives in the parent's text.
        parent.text = lead
    else:
        # Otherwise it lives in the tail of the preceding sibling.
        parent[index - 1].tail = lead

    if index == len(parent) - 1:
        # Last child: place the parent's closing tag one level shallower.
        current.tail = '\n' + ('  ' * (depth - 1))

# Matches SPIP config lines of the form "SUBBAND_CONFIG_<n>  <a>:<b>:<c>"
# where every field is an unsigned integer.
subbandRE = re.compile(r"SUBBAND_CONFIG_(\d+)\s+(\d+):(\d+):(\d+)")

# subband id (string, as captured) ->
#   [frequency, bandwidth, zoom1 freq, zoom1 bw, zoom2 freq, zoom2 bw]
# The four zoom entries stay None until findSubband() assigns a zoom.
subbands = {}

with open(args.spip) as f:
    for line in f:
        m = subbandRE.search(line)
        if m is None:
            continue  # not a subband definition line
        # Only the first two colon-separated fields are used; the third is ignored.
        subbands[m.group(1)] = [float(m.group(2)), float(m.group(3))] + [None] * 4

# Read Experiment XML config file

def getElem(parent, elem):
    """Return the text of the first ``elem`` child found under ``parent``.

    ``elem`` is passed straight to ``parent.find``.  If nothing matches the
    script terminates via ``sys.exit`` with an error message.  The returned
    text is ``None`` when the element exists but is empty.
    """
    found = parent.find(elem)
    if found is not None:
        return found.text
    sys.exit("Error: Could not find <{}>".format(elem))
            
tree = ET.parse(args.config)
root = tree.getroot()

# Experiment name; later written into the generated configuration as the project id.
exper = getElem(root, 'exper')
print("Exper = ", exper)

setups = root.findall('./mode/setup')

# Exactly one <mode>/<setup> is supported.  Report a missing setup with a
# clear message instead of letting setups[0] raise IndexError.
if not setups:
    sys.exit("Error: Could not find <setup>")
if (len(setups)>1):
    sys.exit("Error: Only support single setup")

setup = setups[0]

config = getElem(setup, 'name')
print("Mode ", config)

def findSubband(freq, bandwidth):
    """Assign a zoom band to the subband that fully contains it.

    Searches the module-level ``subbands`` dict, whose values are lists
    ``[freq, bw, zoom1 freq, zoom1 bw, zoom2 freq, zoom2 bw]``.  A zoom
    centred on ``freq`` with width ``bandwidth`` is contained when both of
    its edges lie strictly inside the subband's edges.  On a match the zoom
    is stored in the first free zoom slot of that subband's list.

    Returns the key of the matching subband, or ``None`` when no subband
    contains the zoom.  ``None`` is also returned (after printing an error,
    and without examining the remaining subbands) when the containing
    subband already holds two zooms.  A warning is printed for every subband
    that has a zoom edge inside it without containing the whole zoom.
    """
    zoom_lo = freq - bandwidth/2
    zoom_hi = freq + bandwidth/2

    for name, entry in subbands.items():
        centre = entry[0]
        width = entry[1]
        lo = centre - width/2.0
        hi = centre + width/2.0

        if lo < zoom_lo and zoom_hi < hi:
            print("Matched {}:{} to Subband {}".format(centre, width, name))
            # Slots are (freq, bw) pairs starting at indices 2 and 4.
            for slot in (2, 4):
                if entry[slot] is None:
                    entry[slot] = freq
                    entry[slot + 1] = bandwidth
                    return name
            print("Error: Maximum 2 zooms per subband possible")
            return None

        # Not enclosed: warn if either zoom edge falls inside this subband.
        lower_edge_inside = lo < zoom_lo < hi
        upper_edge_inside = lo < zoom_hi < hi
        if lower_edge_inside or upper_edge_inside:
            print("Warning  Zoom {}:{} straddles Subband {} edge".format(freq, bandwidth, name))

    return None

# Try to place every <zoom> of the setup into one of the SPIP subbands.
for zoom in setup.findall('zoom'):
    freq = float(getElem(zoom, 'frequency'))
    bandwidth = float(getElem(zoom, 'bandwidth'))
    pol = getElem(zoom, 'polarisation')
    sideband = getElem(zoom, 'sideband')
    print("** Zoom:  {}, {}, {}, {}".format(freq, bandwidth, pol, sideband))

    matchSubband = findSubband(freq, bandwidth)
    if matchSubband is not None:
        continue
    # Unmatched zooms are reported but do not stop the script.
    print("Error: Could not match zoom {}:{}".format(freq, bandwidth))

# Summary of each subband and any zooms assigned to it.
for s in subbands:
    print(s, ": ", subbands[s])


def SPIP_configure(subbands, exper, calFreq, calEpoch, calAvg):
    """Build the SPIP "configure" command as an ElementTree element.

    Parameters:
        subbands: dict whose keys are integer-like strings and whose values
            are lists ``[freq, bw, zoom1 freq, zoom1 bw, zoom2 freq, zoom2 bw]``;
            only the four zoom entries (indices 2-5) are read here.
        exper:    stored as-is as the text of the ``project_id`` element.
        calFreq:  written (via ``str``) to the ``CAL_FREQ`` element.
        calEpoch: stored as-is as the text of the ``CAL_EPOCH`` element.
        calAvg:   written (via ``str``) to the ``TSYS_AVG_TIME`` element.

    Returns the root ``<obs_cmd>`` element.  One ``<stream<id>>`` child is
    emitted per subband, in numeric order of the subband ids.

    Note: the VLBI output bit depth and destination port are taken from the
    module-level ``args`` (``args.bits`` / ``args.port``), not from parameters.
    """

    def add(parent, tag, text, **attrs):
        # Append <tag attrs...>text</tag> to parent; attributes keep call order.
        ET.SubElement(parent, tag, **attrs).text = text

    def createZoom(zoom_proc_param, zoomID, active, bandwidth, freq, mode):
        # Emit the active/bandwidth/frequency/mode elements for one zoom slot.
        add(zoom_proc_param, "active{:d}".format(zoomID), str(active),
            key="ZOOM{:d}_ACTIVE".format(zoomID))
        add(zoom_proc_param, "bandwidth{:d}".format(zoomID), str(bandwidth),
            key="ZOOM{:d}_BW".format(zoomID), units="MHz")
        add(zoom_proc_param, "frequency{:d}".format(zoomID), str(freq),
            key="ZOOM{:d}_FREQUENCY".format(zoomID), units="MHz")
        add(zoom_proc_param, "mode{:d}".format(zoomID), str(mode),
            key="ZOOM{:d}_MODE".format(zoomID))

    # Numeric ordering of the string ids (so "2" sorts before "10").
    ordered_ids = sorted(subbands.keys(), key=int)

    root = ET.Element("obs_cmd")
    add(root, "command", "configure")

    beam_config = ET.SubElement(root, "beam_configuration")
    add(beam_config, "nbeam", "1", key="NBEAM")
    add(beam_config, "beam_state_0", "1", key="BEAM_STATE_0", name="1")

    stream_config = ET.SubElement(root, "stream_configuration")
    add(stream_config, "nstream", str(len(ordered_ids)), key="NSTREAM")
    # One identical <active> element per stream.
    for _ in ordered_ids:
        add(stream_config, "active", "1", key="STREAM_ACTIVE")

    source = ET.SubElement(root, "source_parameters")
    add(source, "name", "VLBI", key="SOURCE", epoch="J2000")
    add(source, "ra", "00:00:00", key="RA", units="hh:mm:ss")
    add(source, "dec", "00:00:00", key="DEC", units="dd:mm:ss")

    obs_parm = ET.SubElement(root, "observation_parameters")
    add(obs_parm, "observer", "VLBI", key="OBSERVER")
    add(obs_parm, "project_id", exper, key="PID")
    add(obs_parm, "utc_start", "None", key="UTC_START")
    add(obs_parm, "utc_stop", "None", key="UTC_STOP")

    cal_parm = ET.SubElement(root, "calibration_parameters")
    add(cal_parm, "signal", "1", key="CAL_SIGNAL")
    add(cal_parm, "freq", str(calFreq), key="CAL_FREQ", units="Hertz")
    add(cal_parm, "phase", "0.0", key="CAL_PHASE")
    add(cal_parm, "duty_cycle", "0.5", key="CAL_DUTY_CYCLE")
    add(cal_parm, "epoch", calEpoch, key="CAL_EPOCH", units="YYYY-DD-MM-HH:MM:SS+0")
    add(cal_parm, "tsys_avg_time", str(calAvg), key="TSYS_AVG_TIME", units="seconds")
    add(cal_parm, "tsys_freq_resolution", "1", key="TSYS_FREQ_RES", units="MHz")

    for sub_id in ordered_ids:
        band = subbands[sub_id]
        stream = ET.SubElement(root, "stream{}".format(sub_id))

        custom_param = ET.SubElement(stream, "custom_parameters")
        for tag, key, value in (
            ("adaptive_filter", "ADAPTIVE_FILTER", "0"),
            ("adaptive_filter_epsilon", "ADAPTIVE_FILTER_EPSILON", "0.1"),
            ("adaptive_filter_nchan", "ADAPTIVE_FILTER_NCHAN", "128"),
            ("adaptive_filter_nsamp", "ADAPTIVE_FILTER_NSAMP", "1024"),
            ("schedule_block_id", "SCHED_BLOCK_ID", "0"),
            ("scan_id", "SCAN_ID", "0"),
            ("raw_baseband", "RECORD_RAW_BASEBAND", "0"),
        ):
            add(custom_param, tag, value, key=key)

        # VLBI processing is switched on only when the first zoom slot is filled.
        vlbi_flag = "1" if band[2] is not None else "0"

        processing_modes = ET.SubElement(stream, "processing_modes")
        for tag, key, value in (
            ("fold", "PERFORM_FOLD", "0"),
            ("search", "PERFORM_SEARCH", "0"),
            ("continuum", "PERFORM_CONTINUUM", "0"),
            ("spectral_line", "PERFORM_SPECTRAL_LINE", "0"),
            ("vlbi", "PERFORM_VLBI", vlbi_flag),
            ("baseband", "PERFORM_BASEBAND", "0"),
        ):
            add(processing_modes, tag, value, key=key)

        zoom_proc_param = ET.SubElement(stream, "zoom_processing_parameters")
        # Zoom 1 lives at band[2:4], zoom 2 at band[4:6] (freq, then bandwidth).
        # An empty slot is written as an inactive 4 MHz "baseband" zoom at 832 MHz.
        for zoom_id, freq_idx in ((1, 2), (2, 4)):
            if band[freq_idx] is not None:
                createZoom(zoom_proc_param, zoom_id, 1, band[freq_idx + 1], band[freq_idx], "vlbi")
            else:
                createZoom(zoom_proc_param, zoom_id, 0, 4, 832, "baseband")

        vlbi_proc_param = ET.SubElement(stream, "vlbi_processing_parameters")
        for tag, key, value in (
            ("output_nbit", "VLBI_OUTNBIT", str(args.bits)),
            ("output_encoding", "VLBI_ENCODING", "offset_binary"),
            ("dest_ip", "VLBI_DEST_IP", "10.17.10.1"),
            ("dest_port", "VLBI_DEST_PORT", str(args.port)),
            ("vdif_thread_id", "VLBI_VDIF_THREAD_ID", "0"),
            ("dest_protocol", "VLBI_DEST_PROTOCOL", "udp"),
            ("dest_sequence_number", "VLBI_DEST_SEQ_NO", "1"),
        ):
            add(vlbi_proc_param, tag, value, key=key)

    return root

# Build the configure command (cal frequency 100, placeholder epoch 'SETME',
# Tsys averaging time 5), indent it, and write it to a fixed file name in
# the current directory.
config_xml = SPIP_configure(subbands, exper, calFreq=100, calEpoch='SETME', calAvg=5)
_pretty_print(config_xml)

tree = ET.ElementTree(config_xml)
tree.write(
    "configure-test.xml",
    encoding="ISO-8859-1",
    xml_declaration=True,
)



