#!/usr/bin/env python3

import argparse, time
from astropy.time import Time, TimeDelta
from math import radians, degrees
from legacyserver import LegacyServer
from enum import Enum, auto

# Antenna control server connection details (used by LegacyServer below).
server = 'bigrock.atnf.csiro.au'
port = 2334

parser = argparse.ArgumentParser(
    description="Track a schedule of timed Az/El positions with the antenna.")
parser.add_argument('-fake', '--fake', help="Fake tracking", action="store_true")
parser.add_argument('-s', '--stow', help="Stow the telescope when done", action="store_true")
parser.add_argument(
    'schedule',
    help="Schedule file: one '<ISOT UTC time> <Az deg> <El deg>' point per line; "
         "a 'wrap CLOSEST|NORTH|SOUTH' line selects the wrap; '#' starts a comment")

args = parser.parse_args()

filename = args.schedule

# Schedule points as (astropy Time, Az radians, El radians), filled in from
# the schedule file and consumed from the front as tracking proceeds.
sched = []

class Wrap(Enum):
    """Wrap selection; its name is sent to the antenna server in drive commands."""
    CLOSEST = auto()
    NORTH = auto()
    SOUTH = auto()

    def asStr(self):
        """Return the wrap name as used in server commands, e.g. 'NORTH'."""
        return self.name

def convert_to_wrap(wrap_value):
    """Return the Wrap member corresponding to *wrap_value*.

    Parameters
    ----------
    wrap_value : Wrap or str
        A Wrap member (returned unchanged) or its exact upper-case name:
        'CLOSEST', 'NORTH' or 'SOUTH'.

    Raises
    ------
    ValueError
        If *wrap_value* is a string that is not a wrap name, or is neither a
        Wrap nor a string.
    """
    # NOTE: previously declared as convert_to_wrap(self, wrap_value) even
    # though it is a module-level function; the one-argument call used by the
    # schedule parser therefore raised TypeError.
    if isinstance(wrap_value, Wrap):
        return wrap_value
    if not isinstance(wrap_value, str):
        raise ValueError("Wrap must be either a Wrap object or a string representing a wrap")
    try:
        # Enum name lookup only matches members, never methods like 'asStr'.
        return Wrap[wrap_value]
    except KeyError:
        raise ValueError("Wrap must be CLOSEST, NORTH or SOUTH") from None

def mjd2bat(t, dutc=37.):
    '''
    Convert a UTC Modified Julian Date to ATNF Binary Atomic Time (BAT).

    At the ATNF, BAT is the number of microseconds of atomic time since
    MJD 0 (1858-11-17 00:00).  The conversion is therefore
    BAT = (MJD_utc * 86400 + (TAI - UTC)) * 1e6.

    Parameters
    ----------
    t : float
        UTC time as an MJD in days (e.g. an astropy ``Time(...).mjd``).
    dutc : float, optional
        TAI - UTC offset in seconds.  Defaults to 37 s, the offset since the
        2017-01-01 leap second; override for earlier dates or after any
        future leap second.

    Returns
    -------
    int
        BAT in microseconds, rounded to the nearest microsecond.
    '''
    bat = int(round((t*86400. + dutc) * 1000000))
    return bat
    
def loadtrack(s, wrap, time1, az1, el1, time2, az2, el2):
    '''
    Queue one constant-rate Az/El scan segment on the antenna server.

    Sends an ``sscan`` command, with ``when`` set to *time1* converted to
    BAT, describing a move from (az1, el1) to (az2, el2) at constant rate over
    time1 -> time2.

    Parameters
    ----------
    s : LegacyServer
        Antenna server connection used to send the command.
    wrap : Wrap
        Wrap selection; sent by name.
    time1, time2 : astropy.time.Time
        Segment start and end times; time2 must be later than time1.
    az1, el1, az2, el2 : float
        Start and end positions in radians.

    Raises
    ------
    ValueError
        If time2 is not after time1 (the rates would be undefined or reversed).
    '''
    print("Load {} ({:.2f},{:.2f}) to {} ({:.2f},{:.2f})".format(
        time1, degrees(az1), degrees(el1), time2, degrees(az2), degrees(el2)))

    # Take the segment length from the astropy TimeDelta directly rather than
    # differencing two float MJD values.
    interval = (time2 - time1).sec
    if interval <= 0:
        raise ValueError("Track segment end {} is not after its start {}".format(time2, time1))

    # Rates are position change per second (radians in, seconds interval).
    cmd_args = "{:s} AzElApp {:f} {:f} {:f} {:f} {:f} {:f}".format(
        wrap.asStr(), az1, az2, (az2 - az1)/interval, el1, el2, (el2 - el1)/interval)
    s.send_cmd('sscan', cmd_args, when=mjd2bat(time1.mjd))

def printMon(mon):
    """Redraw the current terminal line with a one-line antenna status.

    Shows the servo state, the Az/El position and the Az/El error taken from
    the monitor record *mon* (its servo_state, azdeg, eldeg, delta_az_deg and
    delta_el_deg attributes).  The text starts with a carriage return and no
    newline is emitted, so successive calls overwrite the same line.
    """
    template = ("\rServo state = {:^10}  AzEl=({:.1f},{:.1f}) "
                "Error=({:.2f},{:.2f})     ")
    status = template.format(mon.servo_state, mon.azdeg, mon.eldeg,
                             mon.delta_az_deg, mon.delta_el_deg)
    print(status, end="")
    
wrap = Wrap.CLOSEST

# Parse the schedule file into sched as (Time, Az, El) tuples, Az/El in
# radians.  Each non-blank line is either
#     <ISOT UTC time> <Az deg> <El deg>   [extra fields are ignored]
# or a wrap directive such as "wrap NORTH" (keyword and value are
# case-insensitive; default is CLOSEST).  A single wrap is used for the whole
# run, so the last wrap directive in the file wins.  Anything after '#' is a
# comment.
with open(filename, "r") as f:
    for lineno, raw_line in enumerate(f, start=1):
        line = raw_line.split('#', 1)[0].strip()
        if not line:
            continue  # Skip blank and comment-only lines

        fields = line.split()

        if fields[0].lower() == "wrap":
            if len(fields) < 2:
                raise SystemExit("{}:{}: wrap needs a value (CLOSEST, NORTH or SOUTH)".format(filename, lineno))
            wrap = convert_to_wrap(fields[1].upper())
            print("Using {} wrap".format(wrap.asStr()))
            continue

        if len(fields) < 3:
            raise SystemExit("{}:{}: expected '<time> <Az> <El>', got {!r}".format(filename, lineno, line))
        timestr, az_deg, el_deg = fields[:3]

        t = Time(timestr, format='isot', scale='utc')
        # Consecutive points become constant-rate track segments, so times
        # must be strictly increasing (equal or reversed times give
        # undefined or reversed rates).
        if sched and t <= sched[-1][0]:
            raise SystemExit("{}:{}: time {} is not after the previous point {}".format(
                filename, lineno, timestr, sched[-1][0].isot))
        sched.append((t, radians(float(az_deg)), radians(float(el_deg))))

if not sched:
    raise SystemExit("No scan positions found in schedule {}".format(filename))

# Hold off until 15 minutes before the first schedule point; connection to
# and allocation of the antenna only happen after this wait.
startTime = sched[0][0] - TimeDelta(15 * 60, format='sec')
if Time.now() < startTime:
    first_point = sched[0][0]
    for message in ("Schedule does not start till {}".format(first_point.iso),
                    "Waiting till {}".format(startTime.iso)):
        print(message)
    while Time.now() < startTime:
        time.sleep(5)
    print("Starting observations")
    
def _drop_past_scans(cutoff):
    """Pop schedule points whose time is at or before *cutoff*.

    Returns True if at least one point was removed.
    """
    dropped = False
    while sched and sched[0][0] <= cutoff:
        sched.pop(0)
        dropped = True
    return dropped


def _load_next_segment(antenna, loadtimes):
    """Load the track segment between the first two schedule points.

    The first point is consumed from sched and the segment end time is
    appended to *loadtimes*.
    """
    (t1, az1, el1), (t2, az2, el2) = sched[0], sched[1]
    loadtrack(antenna, wrap, t1, az1, el1, t2, az2, el2)
    sched.pop(0)
    loadtimes.append(t2)


def _monitor_until(antenna, deadline):
    """Print antenna monitor data about once a second until *deadline*."""
    while Time.now() < deadline:
        printMon(antenna.get_mon())
        time.sleep(1)


# Connect to antenna
s = LegacyServer((server, port), fake=args.fake)
s.connect()
s.allocate()
# From here on, always release the antenna: on normal completion, on the
# early "schedule finished" exits, on errors and on Ctrl-C.
try:
    s.send_cmd('enable')
    s.send_cmd('drvOn')

    if _drop_past_scans(Time.now()):
        print("Schedule already started - skipping ahead")
    if not sched:
        print("Schedule already finished")
        raise SystemExit(1)

    print("Next scan at {}".format(sched[0][0]))
    # Go to first Az/El and wait for the slew to finish
    s.goto(wrap.asStr(), 'AzElApp', sched[0][1], sched[0][2])
    time.sleep(5)

    mon = s.get_mon()
    while mon.servo_state == 'SLEWING':
        printMon(mon)
        time.sleep(2)
        mon = s.get_mon()
    print()

    # Re-check against the *current* time: the slew itself may have carried
    # us past the next schedule point(s).  (Comparing against the time taken
    # before the slew, as previously done, could never trigger.)
    if _drop_past_scans(Time.now()):
        print("Drat, missed start - skip some more")
    if not sched:
        print("Schedule finished")
        raise SystemExit(1)

    # Queue up to 5 segments ahead, then top up one segment each time the
    # oldest queued segment ends.
    loadtimes = []  # end times of segments loaded but not yet finished
    while len(sched) > 1 and len(loadtimes) < 5:
        _load_next_segment(s, loadtimes)

    while len(sched) > 1:
        _monitor_until(s, loadtimes[0])
        print()
        loadtimes.pop(0)
        _load_next_segment(s, loadtimes)

    # Let the queued segments run out.  loadtimes is empty when only a single
    # schedule point remained after skipping, i.e. there was nothing to track.
    if loadtimes:
        _monitor_until(s, loadtimes[-1])
    print()

    if args.stow:
        time.sleep(10)
        print("Stowing antenna")
        s.stow()
finally:
    s.deallocate()
