#!/usr/bin/env python3

import socket, re, logging, sys

class JIVE5AB:
    """Connection to a JIVE5AB server's control port.

    Each command is sent as plain text over TCP and its reply is read with a
    single ``recv(1024)``.  The object remembers whether it switched recording
    on (``recording``) and whether it added a datastream (``usingDatastream``)
    so that ``close()`` -- called explicitly, or from ``__del__`` as a
    fallback -- can undo those settings on the server before disconnecting.
    """

    # Replies are ':'-separated fields with optional surrounding spaces, e.g.
    #   !tstat?  0 : 10.02s : vbsrecord : UdpsNorReadStream 153.197Mbps : F 0.0% ;
    _FIELD_SEP = re.compile(" *: *")

    def __init__(self):
        self.sock = None              # connected socket, or None
        self._recording = False       # set by recordingOn(), cleared by recordingOff()
        self.usingDatastream = False  # set by splitDatastream(), cleared by resetDatastream()
        self.log = logging.getLogger(__name__)

    def __del__(self):
        # Fallback cleanup for callers that never called close().
        self.log.info('*****Deleting J5 object*****')
        self.close()

    @property
    def recording(self):
        """True while this client has recording switched on."""
        return self._recording

    @recording.setter
    def recording(self, value):
        self._recording = value

    def open(self, ip, port, timeout=None):
        """Connect to the server at ip:port.

        ``timeout`` (seconds), if given, is set on the socket before
        connecting and so also applies to later sends/receives.  If the
        connection fails the socket is closed and the object stays
        unconnected, so ``open()`` can be retried.

        Raises:
            AssertionError: if already connected (an ``assert``, so not
                checked under ``python -O``).
            OSError: if the connection cannot be established.
        """
        assert self.sock is None, "Socket already open"
        sock = socket.socket()
        try:
            if timeout is not None:
                sock.settimeout(timeout)
            sock.connect((ip, port))
        except BaseException:
            # Don't leak the socket or leave the object half-open.
            sock.close()
            raise
        self.sock = sock

    def cmd(self, cmd):
        """Send one command string and return the server's reply.

        Returns:
            The reply with leading/trailing CR/LF stripped, or None if the
            socket raised ``OSError`` or the server closed the connection.
            Only a single ``recv(1024)`` is done, so a longer (or split)
            reply would be returned incomplete.

        Raises:
            AssertionError: if not connected.
            Any non-``OSError`` exception from encoding, sending, receiving
            or decoding is logged as critical and re-raised.
        """
        assert self.sock is not None, "Socket not open"
        self.log.debug('Sending "%s"', cmd)
        try:
            self.sock.sendall(cmd.encode())
            data = self.sock.recv(1024)
            reply = data.decode()
        except OSError as e:
            self.log.error("Reading j5 socket: %s", e)
            return None  # Maybe should raise error, close socket etc
        except Exception as e:
            self.log.critical("Unexpected error reading j5 socket: %s", e)
            raise
        if not data:
            # recv() returns b'' only when the peer has closed the connection.
            self.log.error("Reading j5 socket: connection closed by server")
            return None
        reply = reply.strip("\n\r")
        self.log.debug("  : %s", reply)
        return reply

    def _set_or_query(self, name, value=None):
        """Send "name?" if value is None, else "name=value"; log and return the reply."""
        if value is None:
            command = '{}?'.format(name)
        else:
            command = '{}={}'.format(name, value)
        reply = self.cmd(command)
        if reply is not None:
            self.log.info("%s", reply)
        return reply

    def _query_fields(self, command):
        """Send a query; return its reply split into ':'-separated fields, or None on error."""
        reply = self.cmd(command)
        if reply is None:
            return None
        return self._FIELD_SEP.split(reply)

    def mtu(self, mtuval=None):
        """Query (mtuval None) or set the MTU; return the reply (None on error)."""
        return self._set_or_query('mtu', mtuval)

    def net_port(self, port=None):
        """Query (port None) or set the data port; return the reply (None on error)."""
        return self._set_or_query('net_port', port)

    def net_protocol(self, protocol=None, socbuf=None, workbuf=None, nbuf=None):
        """Query (protocol None) or set the network protocol and its buffer options.

        The optional fields are positional in the command, so trailing unset
        (None) fields are omitted, while an unset field that precedes a set
        one is sent empty to keep the later value in its slot, e.g.
        ``net_protocol('udps', nbuf=8)`` sends ``net_protocol=udps :  :  : 8``.
        Returns the reply (None on error).
        """
        if protocol is None:
            return self._set_or_query('net_protocol')
        extras = [socbuf, workbuf, nbuf]
        while extras and extras[-1] is None:
            extras.pop()
        value = '{}'.format(protocol) + ''.join(
            ' : {}'.format('' if field is None else field) for field in extras)
        return self._set_or_query('net_protocol', value)

    def mode(self, modeval=None):
        """Query (modeval None) or set the data mode; return the reply (None on error)."""
        return self._set_or_query('mode', modeval)

    def close(self):
        """Undo this client's server-side settings, then close the socket.

        Recording is switched off and the datastream cleared if this object
        turned them on.  The socket is closed (and ``sock`` reset to None)
        even if one of those commands raises.  Does nothing when not connected.
        """
        if self.sock is None:
            return
        try:
            if self._recording:
                self.recordingOff()
            if self.usingDatastream:
                self.resetDatastream()
        finally:
            self.sock.close()
            self.sock = None

    def recordingOn(self, scan):
        """Send "record=on:<scan>" and mark this client as recording.

        The flag is set even if the command failed, so that close() will
        still try to switch recording off.  Returns the reply (None on error).
        """
        reply = self.cmd("record=on:{}".format(scan))
        self._recording = True
        if reply is not None:
            self.log.debug("%s", reply)
        return reply

    def recordingOff(self):
        """Send "record=off" and clear the recording flag; return the reply (None on error)."""
        reply = self.cmd("record=off")
        self._recording = False
        if reply is not None:
            self.log.debug("%s", reply)
        return reply

    def tstat_query(self):
        """Return the "tstat?" reply split into fields, or None on error.

        Example replies:
            !tstat?  0 : 10.02s : vbsrecord : UdpsNorReadStream 153.197Mbps : F 0.0% ;
            !tstat? 0 : 0.0 : no_transfer ;
        """
        return self._query_fields("tstat?")

    def status_query(self):
        """Return the "status?" reply split into fields, or None on error."""
        return self._query_fields("status?")

    def evlbi_query(self):
        """Return the "evlbi?" reply split into fields, or None on error.

        Example reply:
            !evlbi? 0 : total : 295762 : loss : 0 ( 0.00%) : out-of-order : 0 ( 0.00%) : extent : 0seqnr/pkt ;
        """
        return self._query_fields("evlbi?")

    def rtime(self):
        """Return the "rtime?" reply split into fields, or None on error.

        Example reply:
            !rtime? 0 : 9.96565e+06s : 160088GB : 44.4748% : VDIF (complex) : 8 : 0MHz : 128.512Mbps ;
        """
        return self._query_fields("rtime?")

    def splitDatastream(self):
        """Add a catch-all datastream and remember to clear it on close().

        ``{thread}`` is sent literally (this is not a Python format call);
        presumably it is jive5ab's per-thread file-name placeholder -- confirm
        against the jive5ab datastream documentation.
        """
        reply = self.cmd("datastream=add:{thread}:*")
        self.usingDatastream = True
        return reply

    def resetDatastream(self):
        """Send "datastream=clear" and clear the datastream flag; return the reply."""
        reply = self.cmd("datastream=clear")
        self.usingDatastream = False
        return reply

if __name__ == "__main__":

    from astropy.time import Time
    from astropy.time import TimeDelta
    import time, argparse, os

    parser = argparse.ArgumentParser()
    parser.add_argument('-a', '--antenna', '-antenna', help="Antenna ID to use")
    parser.add_argument('-u', '-update', '--update', help="Monitoring Update", type=int, default=10)
    parser.add_argument('-p', '-port', '--port', help="Control Port", type=int, default=2620)
    parser.add_argument('-host', '--host', help="Jive5ab Host", default='localhost')
    parser.add_argument('schedule', help="Recorder Schedule")
    args = parser.parse_args()

    antenna = args.antenna
    if antenna is None:
        antenna = os.environ.get('JIVE5AB_ANTID')

    # Recorder/connection settings.  Any key can be overridden by a
    # "<key> <value>" line in the schedule file (key matched case-insensitively).
    setup = {'net_port': None,
             'mtu': 9000,
             'net_protocol': None,
             'mode': None,
             'exper': None,
             'antenna': antenna,
             'port': args.port,
             'host': args.host
    }
    setupKeys = {key.lower(): key for key in setup}

    fileType = 'vdif'

    # Interval between tstat?/evlbi? monitoring reports during a scan
    update = TimeDelta(args.update, format='sec')

    now = Time.now()
    # Schedule times carry no year, so the current year is assumed.
    # NOTE(review): a schedule spanning New Year would be mis-dated.
    year = now.datetime.year

    def parseTime(thisTime):
        """Parse a schedule time (presumably 'DDD/HH:MM:SS') as UTC in `year`.

        '/' is turned into ':' and the year is prepended, giving astropy's
        'yday' format (YYYY:DDD:HH:MM:SS).
        """
        t = '{}:{}'.format(year, thisTime.replace('/', ':'))
        return Time(t, format='yday', scale='utc')

    def printStatus(j5):
        """Print a one-line tstat? summary followed by the split evlbi? reply."""
        tstat = j5.tstat_query()
        if tstat is None:
            print("tstat query failed")
        elif len(tstat) > 1 and "Retry - we're initialized now" in tstat[1]:
            print("tstat initialising")
        elif len(tstat) >= 5:
            print("tstat {} {} {}".format(tstat[1], tstat[3], tstat[4]))
        else:
            # Short replies such as "!tstat? 0 : 0.0 : no_transfer ;"
            print("tstat {}".format(" ".join(tstat[1:])))
        print(j5.evlbi_query())

    SCANNAME = 0
    SCANSTART = 1
    SCANEND = 2
    scanList = []

    # Schedule format: '#' starts a comment; every other non-blank line is
    # either a setup line "<key> <value>" or a scan line "<name> <start> <end>".
    with open(args.schedule) as f:
        for line in f:
            line = line.split('#')[0].strip()
            if not line:
                continue
            tokens = line.split()
            key = setupKeys.get(tokens[0].lower())
            if key is not None:
                if len(tokens) != 2:
                    raise ValueError('Error parsing {}'.format(line))
                setup[key] = tokens[1]
            elif len(tokens) != 3:
                raise ValueError('Error parsing {}'.format(line))
            else:
                scanList.append((tokens[0], parseTime(tokens[1]), parseTime(tokens[2])))

    if not scanList:
        print("No scans found in schedule {}".format(args.schedule))
        sys.exit(1)

    # Sort scan list by start time
    scanList.sort(key=lambda x: x[SCANSTART])

    # Check individual scans OK
    for s in scanList:
        if abs((s[SCANEND] - s[SCANSTART]).sec) < 1:
            raise AssertionError("Scan {} too short ({}-{})".format(s[SCANNAME], s[SCANSTART], s[SCANEND]))
        if s[SCANSTART] > s[SCANEND]:
            raise AssertionError("Scan {} stops before it starts! ({}-{})".format(s[SCANNAME], s[SCANSTART], s[SCANEND]))

    # Check no overlap between consecutive scans
    for prev, nxt in zip(scanList, scanList[1:]):
        if prev[SCANEND] > nxt[SCANSTART]:
            raise AssertionError("Scan {} overlaps with {}".format(prev[SCANNAME], nxt[SCANNAME]))

    # Missing names would end up as "None" in the recorded file names
    for key in ('exper', 'antenna'):
        if setup[key] is None:
            print("Warning: no {} set; file names will contain 'None'".format(key))

    # A port read from the schedule file is a string
    setup['port'] = int(setup['port'])

    if now > scanList[-1][SCANEND]:
        print("Observation already finished")
        sys.exit(1)

    if now > scanList[0][SCANSTART]:
        # Need to skip some scans
        startScan = None
        for i, s in enumerate(scanList):
            if s[SCANEND] < now:
                print("Skipping scan {}".format(s[SCANNAME]))
            elif s[SCANSTART] > now:
                startScan = i
                break
            elif s[SCANEND] > now:
                # Current scan already started; recording begins immediately below
                print("Scan {} already started".format(s[SCANNAME]))
                startScan = i
                break
        if startScan is not None:
            scanList = scanList[startScan:]
            print("Starting at scan {}".format(scanList[0][SCANNAME]))
        else:
            print("Something went wrong - no scans found")
            sys.exit(1)

    print("Connecting to {}:{}".format(setup['host'], setup['port']))
    j5 = JIVE5AB()

    try:
        j5.open(setup['host'], setup['port'])
    except OSError as e:
        # Covers refused connections, unreachable hosts and name-lookup failures
        print("Error connecting to jive5ab server: ", e)
        sys.exit(1)

    try:
        # Send setup info (a None value only queries the current setting)
        j5.net_port(setup['net_port'])
        j5.net_protocol(setup['net_protocol'])
        j5.mtu(setup['mtu'])
        j5.mode(setup['mode'])
        j5.splitDatastream()

        for s in scanList:
            now = Time.now()
            if now < s[SCANSTART]:
                print("Scan {} starts in {:.1f} sec".format(s[SCANNAME], (s[SCANSTART] - now).sec))
            while now < s[SCANSTART]:
                # Sleep half the remaining time so the wake-up converges on the start
                time.sleep((s[SCANSTART] - now).sec / 2.0)
                now = Time.now()

            print("    {} {} {}".format(s[SCANNAME], s[SCANSTART], s[SCANEND]))
            print("**", now)
            j5.recordingOn("{}_{}_{}.{}".format(setup['exper'], setup['antenna'], s[SCANNAME], fileType))
            j5.tstat_query()  # Resets counters
            lastUpdate = now
            while now < s[SCANEND]:
                if (now - lastUpdate) > update:
                    lastUpdate = now
                    printStatus(j5)
                time.sleep(1)
                now = Time.now()
            j5.recordingOff()
            print("   **", now)
    finally:
        # Leave the server clean even after an error or ^C
        try:
            if j5.recording:
                j5.recordingOff()
            if j5.usingDatastream:
                j5.resetDatastream()
        finally:
            j5.close()
