#!/usr/bin/env python3

import argparse, sys, time, os, logging
from astropy.time import Time
from astropy.time import TimeDelta
from astropy import units as u
from astropy.coordinates import Angle
from skyfield.api import load, EarthSatellite, wgs84
from math import degrees, ceil
from enum import Enum, auto

from legacyserver import LegacyServer

# Antenna control server (host, port) the script connects to.
server = 'bigrock.atnf.csiro.au'
port = 2334

# Command-line options; parsed once at import time into the module-level ``args``.
parser = argparse.ArgumentParser()
parser.add_argument('-tleupdate', '--tleupdate', help="Interpolation size for TLE (sec)", type=int, default=10)
parser.add_argument('-tlequeue', '--tlequeue', help="Number of TLE segments to pre-load", type=int, default=5)
parser.add_argument('-stow', '--stow', help="Stow antenna at end of observation", action="store_true")
parser.add_argument('-debug', '--debug', help="Enable debug messages", action="store_true")
parser.add_argument('-fake', '--fake', help="Fake tracking", action="store_true")
# Path of the schedule file to observe.
parser.add_argument('schedule')

args = parser.parse_args()

# Module-level flag checked by the helper functions before their debug print()s.
debug = args.debug

# NOTE(review): "False and" short-circuits this condition, so --debug never
# switches the logging module to DEBUG (it only enables the debug print()s).
# Presumably deliberate, e.g. to keep library debug output quiet -- confirm.
if False and args.debug:
    logging.basicConfig(level=logging.DEBUG)
else:
    logging.basicConfig(level=logging.INFO)

class Wrap(Enum):
    """Cable-wrap choice for an antenna move; ``asStr()`` gives its command token."""

    CLOSEST = auto()
    NORTH = auto()
    SOUTH = auto()

    def asStr(self):
        """Return the member name, e.g. ``'NORTH'``."""
        label = self.name
        return label

class Mode(Enum):
    """How a scan is pointed: fixed J2000 position, satellite TLE, or fixed Az/El."""

    J2000 = auto()
    TLE = auto()
    AZEL = auto()

    def asStr(self):
        """Return the member name, e.g. ``'TLE'``."""
        label = self.name
        return label

    def command(self):
        """Return the antenna coordinate-frame keyword for this mode.

        Modes without a single fixed frame (TLE) yield ``None``.
        """
        # Kept inside the method: a dict assigned in an Enum class body
        # would itself become an enum member.
        frames = {Mode.J2000: 'J2000Mean', Mode.AZEL: 'AzElApp'}
        return frames.get(self)
        
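# Example usage (left/right are required; J2000 left is in hours, right in degrees):
#scan = Scan(name="Test Scan", start="2024-01-01T00:00:00", stop="2024-01-01T01:00:00", left="12:30:00", right="-45:00:00", mode="J2000")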
class Scan:
    """A single entry of the observing schedule.

    Holds the scan name, pointing mode, start/stop times, cable-wrap choice
    and (for fixed-position scans) two coordinates.  String inputs are parsed
    into Mode / astropy Time / Wrap / astropy Angle objects.  ``tle`` starts
    as None and is left for the caller to fill in.
    """

    def __init__(self, name: str, start, stop, left, right, mode=Mode.J2000, wrap=Wrap.CLOSEST):
        self.name = name
        # Mode is resolved first because it decides how ``left`` is parsed.
        self.mode = self._convert_to_mode(mode)
        self.start = self._convert_to_time(start)
        self.stop = self._convert_to_time(stop)
        self.wrap = self._convert_to_wrap(wrap)

        # For J2000 scans the first coordinate is read in hours (presumably
        # right ascension); everything else is read in degrees.
        left_in_hours = self.mode == Mode.J2000
        self.left = self._convert_to_angle(left, left_in_hours)
        self.right = self._convert_to_angle(right, False)
        self.tle = None

    def _convert_to_time(self, time_value):
        """Return *time_value* as an astropy Time, parsing strings."""
        if isinstance(time_value, Time):
            return time_value
        if isinstance(time_value, str):
            return Time(time_value)
        raise ValueError("start and stop must be either an AstroPy Time object or a string representing a time")

    def _convert_to_mode(self, mode_value):
        """Return *mode_value* as a Mode; strings must be 'J2000', 'TLE' or 'AZEL'."""
        if isinstance(mode_value, Mode):
            return mode_value
        if not isinstance(mode_value, str):
            raise ValueError("mode must be either a Mode object or a string representing a mode")
        by_name = {'J2000': Mode.J2000, 'TLE': Mode.TLE, 'AZEL': Mode.AZEL}
        parsed = by_name.get(mode_value)
        if parsed is None:
            raise ValueError("Mode must be J2000, AZEL or TLE")
        return parsed

    def _convert_to_wrap(self, wrap_value):
        """Return *wrap_value* as a Wrap; strings must be 'CLOSEST', 'NORTH' or 'SOUTH'."""
        if isinstance(wrap_value, Wrap):
            return wrap_value
        if not isinstance(wrap_value, str):
            raise ValueError("wrap must be either a Wrap object or a string representing a wrap")
        by_name = {'CLOSEST': Wrap.CLOSEST, 'NORTH': Wrap.NORTH, 'SOUTH': Wrap.SOUTH}
        parsed = by_name.get(wrap_value)
        if parsed is None:
            raise ValueError("Wrap must be CLOSEST, NORTH or SOUTH")
        return parsed

    def _convert_to_angle(self, angle_value, hms=False):
        """Return *angle_value* as an astropy Angle, or None if it is None.

        Strings are parsed in hours when *hms* is true, otherwise in degrees.
        """
        if angle_value is None:
            return None
        if isinstance(angle_value, Angle):
            return angle_value
        if not isinstance(angle_value, str):
            raise ValueError("left and right must be either an AstroPy Angle object or a string representing an angle")
        unit = u.hour if hms else u.deg
        return Angle(angle_value, unit=unit)

class TLEpoint:
    """One sampled satellite position: a time plus azimuth and elevation.

    In this script az/el are radians (calculateAzel stores ``.radians``).
    """

    def __init__(self, time, az, el):
        self.time, self.az, self.el = time, az, el
        
def _tai_minus_utc(t):
    """Return TAI-UTC, in whole seconds, at the instant *t* (an astropy Time).

    ``t.tai - t.utc`` cannot be used for this.  Subtracting two astropy Time
    objects gives the elapsed time between the instants they represent, and
    those two are the *same* instant in different scales, so the result is
    always ~0 s.  Instead, compare the Julian-date representations of the two
    scales directly.
    """
    tai = t.tai
    utc = t.utc
    days = (tai.jd1 - utc.jd1) + (tai.jd2 - utc.jd2)
    # TAI-UTC has been an integral number of seconds since 1972, so round
    # away floating-point noise.
    return int(round(float(days) * 86400.0))


# Get the current time in UTC
current_time_utc = Time.now()
# Convert the current time to TAI
current_time_tai = current_time_utc.tai
# TAI-UTC in seconds (the leap-second offset) at start-up, used by mjd2bat().
# Assumed constant for the whole run.
dutc = _tai_minus_utc(current_time_utc)
        
def mjd2bat(t, offset=None):
    '''
    Convert a UTC Modified Julian Date to Binary Atomic Time (BAT).

    At the ATNF, BAT corresponds to the number of microseconds of atomic
    clock since MJD 0 (1858-11-17 00:00).

    t      -- UTC time as an MJD (days).
    offset -- TAI-UTC in seconds.  Defaults to the module-level ``dutc``,
              which is computed when the script starts, so the default is
              wrong for times on the other side of a leap second; pass an
              explicit value for historic calculations.

    Returns the BAT as an int (microseconds).
    '''
    if offset is None:
        offset = dutc
    return int(round((t*86400. + offset) * 1000000))

def loadtrack(s, wrap, time1, az1, el1, time2, az2, el2):
    '''
    Queue one straight-line Az/El track segment on the antenna.

    The segment runs from (az1, el1) at time1 to (az2, el2) at time2.  Angles
    are in radians and rates are sent in radians per second.  time1/time2
    must provide ``.mjd`` and ``.iso`` (astropy Time objects in this script).
    The segment is sent as an ``sscan`` command scheduled at the BAT of time1.

    Raises ValueError if time2 is not after time1, since no finite rate exists.
    '''
    if debug: print("Load {} ({:.2f},{:.2f}) to {} ({:.2f},{:.2f})".format(time1.iso, degrees(az1), degrees(el1), time2.iso, degrees(az2), degrees(el2)))

    interval = (time2.mjd - time1.mjd)*24*60*60  # segment length in seconds
    if interval <= 0:
        raise ValueError("Track segment end time must be after its start time")
    azRate = (az2-az1)/interval
    elRate = (el2-el1)/interval

    # Named cmd_args so it does not shadow the module-level parsed CLI ``args``.
    cmd_args = "{} AzElApp {:f} {:f} {:f} {:f} {:f} {:f}".format(wrap.asStr(), az1, az2, azRate, el1, el2, elRate)
    s.send_cmd('sscan', cmd_args, when=mjd2bat(time1.mjd))

def loadTLE(satellite, tle_dir):
    '''
    Load ``<tle_dir>/<satellite>.tle`` and return it as a skyfield EarthSatellite.

    The file must contain a three-line element set: a name line followed by
    the two TLE lines.  Blank lines and trailing whitespace are ignored.
    Prints an error and exits with status 1 if the file is missing or does not
    hold exactly three non-blank lines.
    '''
    filename = os.path.join(tle_dir, satellite+".tle")
    if not os.path.isfile(filename):
        print(f"Error: '{filename}' does not exist, or is not plain file")
        sys.exit(1)
    with open(filename, 'r') as file:
        # Strip line endings/trailing whitespace (otherwise they end up in the
        # satellite name) and drop blank lines such as a stray empty last line.
        lines = [line.rstrip() for line in file if line.strip()]
    if len(lines)!=3:
        print("Error: TLE file must be 3 lines, not {}".format(len(lines)))
        sys.exit(1)
    return EarthSatellite(lines[1], lines[2], lines[0].strip(), ts)

def calculateAzel(obs, satellite, start, stop, step):
    '''
    Sample a satellite's azimuth/elevation as seen from *obs*.

    Samples are *step* seconds apart (an int, since it is used with range()).
    They begin at *start* and continue until one lands at or after *stop*.
    Returns a list of TLEpoint holding astropy UTC times and angles in radians.
    '''
    # Enough samples that the last one is at or beyond stop.
    count = int(ceil((stop - start).sec / step)) + 1
    year, month, day, hour, minute, second = start.ymdhms
    offsets = range(0, count * step, step)
    times = ts.utc(year, month, day, hour, minute, second + offsets)

    # Position of the satellite relative to the observer at each sample time.
    topocentric = (satellite - obs).at(times)
    alt, az, _ = topocentric.altaz()

    utc_times = times.to_astropy().utc
    return [TLEpoint(when, a, e) for when, a, e in zip(utc_times, az.radians, alt.radians)]

def printMon(mon):
    '''Overwrite the current terminal line with a one-line servo status summary.'''
    status = "\rServo state = {:^10}  AzEl=({:.1f},{:.1f}) Error=({:.2f},{:.2f})     ".format(
        mon.servo_state, mon.azdeg, mon.eldeg, mon.delta_az_deg, mon.delta_el_deg)
    print(status, end="")

# Directory to load TLEs
tle_dir = os.getenv('TLE_DIR', os.path.expanduser('~/tle'))
# Check if the directory exists
if not os.path.isdir(tle_dir):
    print(f"Error '{tle_dir}' does not exist")
    sys.exit(1)
ts = load.timescale()

# Observatory location: latitude and longitude in degrees, elevation in metres.
Mopra = wgs84.latlon(-31.2678084, 149.0996449, 867.329)

filename = args.schedule

# Load the schedule.  Each non-comment line has the form
#     name mode start stop [left right] [wrap]
# where left/right are present only for J2000 and AZEL scans and wrap
# defaults to CLOSEST.
scans = []
with open(filename, 'r') as file:
    for line in file:
        # Remove any comments and strip whitespace
        line = line.split('#', 1)[0].strip()
        if not line:
            continue  # Skip blank lines

        # Split the line into parts
        parts = line.split()
        if len(parts) < 4:  # Skip lines that don't have at least 4 parts
            print(f"Warning: Skipping {line}")
            continue

        # Extract the values
        name, mode, start, stop = parts[:4]
        iwrap = 4
        if mode == 'J2000' or mode == 'AZEL':
            # Fixed-position scans also need both coordinates.
            if len(parts) < 6:
                print(f"Warning: Skipping {line}")
                continue
            left, right = parts[4], parts[5]
            iwrap = 6
        else:
            left = None
            right = None

        wrap = parts[iwrap] if len(parts) > iwrap else 'CLOSEST'

        # Create a Scan object and append it to the list
        scan = Scan(name=name, start=start, stop=stop, left=left, right=right, mode=mode, wrap=wrap)
        if scan.mode == Mode.TLE:
            scan.tle = loadTLE(scan.name, tle_dir)
        scans.append(scan)

if not scans:
    print("Error: schedule contains no scans")
    sys.exit(1)

# TODO - check scans in time order and no overlap

# Should we wait?  Don't connect until 15 minutes before the first scan.
startTime = scans[0].start - TimeDelta(15*60, format='sec')
if Time.now() < startTime:
    print("Schedule does not start till {}".format(scans[0].start.iso))
    print("Waiting till {}".format(startTime.iso))
    while Time.now() < startTime:
        time.sleep(5)
    print("Starting observations")

update = TimeDelta(2 * u.second)  # Check stuff every 2 seconds


def _skip_missed_points(points, now):
    """Drop past TLE points from the front of *points*, in place.

    If the first point is before *now*, announce it, drop it, and drop every
    following point that is not after *now*.  Safe on an empty list.
    """
    if points and points[0].time < now:
        print("Missed start - skip till next valid")
        points.pop(0)
        while points and points[0].time <= now:
            points.pop(0)


def _load_next_segment(s, wrap, points, loadtimes):
    """Send the segment points[0] -> points[1] to the antenna and advance.

    Removes points[0].  If another segment remains, appends the time of the
    new points[1] to *loadtimes*; that is when the next segment is queued.
    """
    p0, p1 = points[0], points[1]
    loadtrack(s, wrap, p0.time, p0.az, p0.el, p1.time, p1.az, p1.el)
    points.pop(0)
    # With a single point left there is no further segment (and points[1]
    # would raise IndexError).
    if len(points) > 1:
        loadtimes.append(points[1].time)


def _track_fixed(s, scan):
    """Point at a fixed J2000 or Az/El position and monitor until scan.stop."""
    print(f"Goto {scan.name}")
    s.goto(scan.wrap.asStr(), scan.mode.command(), str(scan.left.rad), str(scan.right.rad))

    now = Time.now()
    while now < scan.stop:
        # Poll every `update`, but don't sleep past the end of the scan.
        sleepTime = update if now + update < scan.stop else scan.stop - now
        time.sleep(sleepTime.sec)
        printMon(s.get_mon())
        now = Time.now()
    print()


def _track_tle(s, scan):
    """Track a TLE scan by sending the antenna a rolling queue of Az/El segments.

    The satellite path is sampled every ``args.tleupdate`` seconds.  The
    antenna slews to the first future sample, then up to ``args.tlequeue``
    segments (at least one) are pre-loaded.  After that, one more segment is
    sent each time the earliest queued load time passes, until ``scan.stop``.
    """
    tlePoints = calculateAzel(Mopra, scan.tle, scan.start, scan.stop, args.tleupdate)

    if debug:
        for x in tlePoints:
            print('{} ({:.5f},{:.5f})'.format(x.time.iso, x.az, x.el))

    _skip_missed_points(tlePoints, Time.now())
    if not tlePoints:
        print("Scan finished")
        return

    # Go to first Az/El
    print("Slewing to ({:.1f},{:.1f})".format(degrees(tlePoints[0].az), degrees(tlePoints[0].el)))
    s.goto(scan.wrap.asStr(), 'AzElApp', tlePoints[0].az, tlePoints[0].el)
    # NOTE(review): fixed 6 s pause before polling the servo; the reason is not
    # visible here (presumably lets the drive start reporting SLEWING) -- confirm.
    time.sleep(6)
    now = Time.now()
    mon = s.get_mon()
    while now < scan.stop and mon.servo_state == 'SLEWING':
        printMon(mon)
        time.sleep(1)
        mon = s.get_mon()
        now = Time.now()
    print()

    if now > scan.stop:
        return

    # Repeat the skip, as the slew probably means some more intervals were missed
    _skip_missed_points(tlePoints, Time.now())
    if not tlePoints:
        print("Scan finished")
        return

    print("Starting at {}".format(tlePoints[0].time.iso))

    # Pre-load the first segments.  At least one is queued so that loadtimes
    # is non-empty whenever segments remain (with none, loadtimes[0] fails).
    loadtimes = []
    for _ in range(max(1, args.tlequeue)):
        if len(tlePoints) < 2:
            break
        _load_next_segment(s, scan.wrap, tlePoints, loadtimes)

    # Each time the earliest load time passes, queue one more segment.
    now = Time.now()
    while now < scan.stop and len(tlePoints) > 1 and loadtimes:
        if debug: print("Next load", loadtimes[0].iso)
        while now < scan.stop and now < loadtimes[0]:
            printMon(s.get_mon())
            time.sleep(1)
            now = Time.now()
        if debug: print()
        if now >= scan.stop:
            break  # don't queue segments beyond the end of the scan
        loadtimes.pop(0)
        _load_next_segment(s, scan.wrap, tlePoints, loadtimes)

    # Wait until end of scan
    while now < scan.stop:
        printMon(s.get_mon())
        time.sleep(1)
        now = Time.now()
    print()


def _run_schedule(s, scans):
    """Enable the drives, observe every remaining scan in order, optionally stow."""
    s.send_cmd('enable')
    s.send_cmd('drvOn')

    now = Time.now()
    if scans[0].start < now:
        if scans[0].stop < now:
            print("Schedule already started - skipping ahead")
            while scans and scans[0].stop <= now:
                scans.pop(0)
        else:
            print("First scan already started")

    if not scans:
        print("Schedule already finished")
        sys.exit(1)

    print("Next scan at {}".format(scans[0].start))

    for scan in scans:
        print(f"Doing scan {scan.name}")
        if Time.now() > scan.stop:
            print(f"Skipping Scan {scan.name}, already finished")
            continue
        if scan.mode == Mode.J2000 or scan.mode == Mode.AZEL:  # Move to source straight away
            _track_fixed(s, scan)
        elif scan.mode == Mode.TLE:
            _track_tle(s, scan)
        else:
            print("Unsupported mode type {}".format(scan.mode))
            sys.exit(1)

    if args.stow:
        time.sleep(10)
        print("Stowing antenna")
        s.stow()


# Connect to antenna
s = LegacyServer((server,port), fake=args.fake)
s.connect()
s.allocate()
try:
    _run_schedule(s, scans)
finally:
    # Always release the antenna, including after sys.exit(), errors or Ctrl-C.
    s.deallocate()
