from math import pi, cos, sin, acos, asin, atan2, floor, sqrt
from numpy import deg2rad, rad2deg
import numpy as np

import sys
# Refuse to run on interpreters older than Python 3.4.
if not sys.version_info >= (3, 4):
    raise Exception('Need  Python >= 3.4')

def posrad(rad):
    """Wrap an angle in radians into the range [0, 2*pi].

    Values already in range are returned unchanged (2*pi stays 2*pi).
    Values within a moderate range are wrapped by adding/subtracting whole
    turns.  Very large magnitudes are first reduced with the modulo
    operator: repeated subtraction would need ~|rad|/(2*pi) iterations, and
    never terminates once 2*pi is smaller than the float spacing of ``rad``
    (or for +/-inf).  Infinite input therefore yields NaN; NaN is returned
    unchanged.
    """
    two_pi = 2*pi
    if abs(rad) > 1e6:  # Far outside one turn: reduce in a single step
        rad = rad % two_pi
    while rad < 0:
        rad += two_pi
    while rad > two_pi:
        rad -= two_pi
    return rad

def rad2hour(rad):
    """Convert an angle in radians to hours, or pass None straight through.

    The angle is first wrapped with posrad(), so the result lies in [0, 24].
    """
    return None if rad is None else posrad(rad) / 2 / pi * 24

def Jy2Str(Jy):
    """Format a flux density given in Jy as a short string with a suitable unit.

    Values above 1 Jy are shown in Jy, values above 0.1 mJy in mJy, and
    anything smaller in uJy.  The number of decimals depends on the magnitude.
    """
    if Jy > 100:
        return "{:.0f} Jy".format(Jy)
    if Jy > 1:
        return "{:.1f} Jy".format(Jy)

    mJy = Jy*1000
    if Jy > 0.050:
        return "{:.1f} mJy".format(mJy)
    if Jy > 0.001:
        return "{:.2f} mJy".format(mJy)
    if Jy > 0.0001:
        return "{:.3f} mJy".format(mJy)

    return "{:.1f} uJy".format(Jy*1e6)

from enum import Enum
class Array(Enum):
    """Telescope arrays/networks used to group entries in the telescope table."""

    LBA = 1    # Long Baseline Array
    EVN = 2    # European VLBI Network
    VLBA = 3   # Very Long Baseline Array
    EAVN = 4   # East Asian VLBI Network
    SKA = 5    # Square Kilometre Array
    SEAVN = 6  # South-East Asian VLBI Network
    
# Cartesian Coordinates

class Cartesian:
    """A mutable 3-D vector with components x, y and z.

    In this module it holds geocentric telescope positions and baseline
    vectors.
    """

    def __init__(self, x, y, z):
        self.x, self.y, self.z = x, y, z

    def length(self):
        """Return the Euclidean length of the vector."""
        sum_of_squares = self.x**2 + self.y**2 + self.z**2
        return sqrt(sum_of_squares)

    def longlat(self):
        """Return (longitude, latitude) in radians of the vector's direction.

        Longitude is measured in the x-y plane from +x towards +y; latitude
        is the angle above that plane.
        """
        x, y, z = self.x, self.y, self.z
        longitude = atan2(y, x)
        latitude = atan2(z, sqrt(x**2 + y**2))
        return (longitude, latitude)

    def offset(self, dx, dy, dz):
        """Shift the vector in place by adding (dx, dy, dz)."""
        self.x += dx
        self.y += dy
        self.z += dz

class Telescope:
    """A single telescope: position, elevation limit, SEFD and array membership.

    Angles stored on the object (longitude, latitude, el_limit, rise, set)
    are in radians.  Source visibility is expressed in Greenwich sidereal
    time (GST), also in radians, with LST = GST + longitude.
    """

    def __init__(self, name, code, array, el_limit, SEFD, location=None):
        """Create a telescope.

        Args:
            name: Telescope name.
            code: Short antenna code (e.g. 'At').
            array: The Array member(s) this telescope belongs to; normally a
                list, but a single value is also accepted by isArray().
            el_limit: Lower elevation limit, in degrees.
            SEFD: System equivalent flux density, in Jy.
            location: [longitude, latitude] in degrees (the position is then
                placed on a sphere of radius R_EARTH), or [X, Y, Z]
                geocentric Cartesian coordinates in metres.

        Raises:
            ValueError: if location is missing or has neither 2 nor 3 elements.
        """
        R_EARTH = 6378.137e3  # Equatorial Earth radius, m

        if location is None:
            location = []

        self.name = name
        self.code = code
        self.array = array
        self.el_limit = deg2rad(el_limit)
        self.sefd = SEFD  # Jy
        self.maxBandwidth = 128 * 1e6  # Hz (128 MHz)
        self.rise = None  # GST of rise in radians; set by calcrise()
        self.set = None   # GST of set in radians; set by calcrise()

        if len(location) == 2:
            self.longitude = deg2rad(location[0])
            self.latitude = deg2rad(location[1])

            phi = pi/2 - self.latitude  # Colatitude
            x = R_EARTH * sin(phi) * cos(self.longitude)
            y = R_EARTH * sin(phi) * sin(self.longitude)
            z = R_EARTH * cos(phi)
            self.cartesian = Cartesian(x, y, z)
        elif len(location) == 3:
            self.setCartesianXYZ(location[0], location[1], location[2])
        elif len(location) == 0:
            raise ValueError('Must pass location')
        else:
            raise ValueError('Unknown location with {} elements'.format(len(location)))

    def __str__(self):
        """Return 'name: longitude latitude el_limit' with the angles in degrees."""
        return "{0}: {1:.3f} {2:.3f} {3:.1f}".format(
            self.name, rad2deg(self.longitude), rad2deg(self.latitude), rad2deg(self.el_limit))

    def Cartesian(self):
        """Return the geocentric Cartesian position of the telescope, in m."""
        return self.cartesian

    def Latitude(self):
        """Return the latitude in radians."""
        return self.latitude

    def Longitude(self):
        """Return the longitude in radians."""
        return self.longitude

    def setCartesian(self, loc):
        """Set the Cartesian position (m) and derive longitude/latitude from it."""
        self.cartesian = loc
        (lon, lat) = self.cartesian.longlat()
        self.longitude = lon
        self.latitude = lat

    def setCartesianXYZ(self, x, y, z):
        """Set the Cartesian position from its x, y, z components, in m."""
        self.setCartesian(Cartesian(x, y, z))

    @staticmethod
    def _acos(value):
        """Return acos(value) with value clamped to [-1, 1].

        Rounding can push expressions such as sin(a)*sin(b) + cos(a)*cos(b)
        slightly outside [-1, 1] (e.g. when dec equals the latitude), which
        would make math.acos raise ValueError.  NaN propagates unchanged.
        """
        return acos(max(min(value, 1.0), -1.0))

    @staticmethod
    def _gstGrid(step):
        """Return GST samples 0, step, 2*step, ... (radians) below one full turn."""
        return [i*step for i in range(int(floor(2*pi/step)))]

    def calcrise(self, ra, dec):
        """Compute the GST rise and set of a source and store them in .rise/.set.

        Args:
            ra, dec: Source position in radians.

        Afterwards .rise = .set = None if the source never rises above
        el_limit, .rise = 0 and .set = 2*pi if it never sets, and otherwise
        the rise and set GSTs in radians wrapped into [0, 2*pi].
        """
        z_limit = pi/2 - self.el_limit  # Zenith-distance limit
        sin_prod = sin(self.latitude)*sin(dec)
        cos_prod = cos(self.latitude)*cos(dec)

        # Zenith distance at the highest point: never rises if beyond the limit
        if self._acos(sin_prod + cos_prod) > z_limit:
            self.rise = None
            self.set = None
            return

        # Zenith distance at the lowest point: circumpolar if within the limit
        if self._acos(sin_prod - cos_prod) < z_limit:
            self.rise = 0
            self.set = 2*pi
            return

        # Hour angle at which the zenith distance equals the limit
        ha = self._acos((cos(z_limit) - sin_prod)/cos_prod)
        self.rise = posrad(ra - ha - self.longitude)
        self.set = posrad(ra + ha - self.longitude)

    def isUp(self, gst):
        """Return whether the source is above el_limit at the given GST (radians).

        calcrise() must have been called first.  Returns None if the source
        never rises.  The rise and set instants count as "up", so a
        circumpolar source (rise=0, set=2*pi) is up at every GST in
        [0, 2*pi], including GST 0.
        """
        if self.rise is None:
            return None
        if self.rise < self.set:  # Up-window does not wrap through GST 0
            return bool(self.rise <= gst <= self.set)
        return bool(gst <= self.set or gst >= self.rise)  # Wraps through GST 0

    def upTimes(self, RA, Dec, step):
        """Return the GSTs at which the source is above the elevation limit.

        Args:
            RA, Dec: Source position in radians.
            step: GST sampling interval, in radians.

        Also updates .rise/.set via calcrise().
        """
        self.calcrise(RA, Dec)
        return [gst for gst in self._gstGrid(step) if self.isUp(gst)]

    def calcUp(self, RA, Dec, step):
        """Store in .gst (numpy array) the sampled GSTs when the source is up.

        The sampling is the same as upTimes(); calcAzEl() uses .gst by default.
        """
        self.gst = np.array(self.upTimes(RA, Dec, step))

    def AzEl(self, RA, Dec, gst):
        """Return (azimuth, elevation) in radians of a source at (RA, Dec).

        Azimuth is measured from north towards east, in (-pi, pi].

        Args:
            RA, Dec: Source position in radians.
            gst: Greenwich sidereal time in radians.
        """
        lst = posrad(gst + self.longitude)
        ha = posrad(lst - RA)
        sphi = sin(self.latitude)
        cphi = cos(self.latitude)
        sha = sin(ha)
        cha = cos(ha)
        sdec = sin(Dec)
        cdec = cos(Dec)
        az = atan2(-sha, -cha*sphi + sdec*cphi/cdec)
        # Clamp: rounding can push sin(el) just past +/-1 near the zenith/nadir
        el = asin(max(min(cha*cdec*cphi + sdec*sphi, 1.0), -1.0))
        return (az, el)

    def calcAzEl(self, RA, Dec, gsts=None):
        """Return two lists (Az, El), in radians, one entry per GST.

        Args:
            RA, Dec: Source position in radians.
            gsts: Iterable of GSTs in radians; defaults to .gst from calcUp().
        """
        if gsts is None:
            gsts = self.gst

        Az = []
        El = []
        for gst in gsts:
            az, el = self.AzEl(RA, Dec, gst)
            Az.append(az)
            El.append(el)
        return (Az, El)

    def isArray(self, array):
        """Return True if this telescope belongs to ``array``.

        self.array may be a sequence of Array members or a single value.
        """
        members = self.array
        if not isinstance(members, (list, tuple, set, frozenset, np.ndarray)):
            members = [members]
        return any(a == array for a in members)

class Baseline:
    """A baseline between two telescopes.

    The baseline vector is ant1 minus ant2 in geocentric Cartesian
    coordinates (m).  RA, Dec and GST arguments are in radians.
    """
    C = 299792458  # Speed of light, m/s

    def __init__(self, ant1, ant2):
        """Build the baseline vector from the Cartesian positions of ant1 and ant2."""
        self.ant1 = ant1
        self.ant2 = ant2

        self.name = ant1.name + "->" + ant2.name
        self.xyz1 = ant1.Cartesian()
        self.xyz2 = ant2.Cartesian()
        self.baseline = Cartesian(self.xyz1.x - self.xyz2.x,
                                  self.xyz1.y - self.xyz2.y,
                                  self.xyz1.z - self.xyz2.z)
        self.length = self.baseline.length()  # m
        self.gsts = None  # Not set by any Baseline method; kept for compatibility
        self.gst = None   # GSTs when the source is up at both ends; set by calcUp()
        self.u = None     # Set by UVtrack()
        self.v = None     # Set by UVtrack()

    @staticmethod
    def _hourAngle(RA, gst):
        """Return RA - GST in radians, i.e. minus the Greenwich hour angle.

        The baseline is expressed in Earth-fixed geocentric coordinates, so
        the source direction is rotated by the Greenwich (not a local) hour
        angle.
        """
        return RA - gst

    def calcUp(self, RA, Dec, step):
        """Store in .gst the sampled GSTs when the source is up at both telescopes.

        Args:
            RA, Dec: Source position in radians.
            step: GST sampling interval, in radians.
        """
        self.ant1.calcrise(RA, Dec)
        self.ant2.calcrise(RA, Dec)

        allgst = [i*step for i in range(int(floor(2*pi/step)))]  # One full turn (24 h)
        self.gst = np.array([gst for gst in allgst
                             if self.ant1.isUp(gst) and self.ant2.isUp(gst)])

    def uv(self, RA, Dec, gst, wavelength=None):
        """Return the baseline (u, v) at one GST.

        Args:
            RA, Dec: Source position in radians.
            gst: Greenwich sidereal time in radians.
            wavelength: Observing wavelength in m, giving u and v in
                wavelengths.  If omitted, 1000 is used so that u and v are
                returned in km.

        Returns:
            Tuple (u, v).
        """
        ha = self._hourAngle(RA, gst)
        sHa = sin(ha)
        cHa = cos(ha)
        sDec = sin(Dec)
        cDec = cos(Dec)

        if wavelength is None:
            wavelength = 1000  # Return km

        b = self.baseline
        u = (-b.x*sHa + b.y*cHa)/wavelength
        v = (-b.x*cHa*sDec - b.y*sHa*sDec + b.z*cDec)/wavelength
        return (u, v)

    def delay(self, RA, Dec, gst):
        """Return the geometric delay (B . s)/C of the baseline, in microseconds.

        s is the unit vector towards the source, expressed in the same
        Earth-fixed frame as the baseline B.  The Greenwich hour angle is
        therefore used, consistent with uv(); no antenna longitude enters.

        Args:
            RA, Dec: Source position in radians.
            gst: Greenwich sidereal time in radians.
        """
        ha = self._hourAngle(RA, gst)
        sHa = sin(ha)
        cHa = cos(ha)
        sDec = sin(Dec)
        cDec = cos(Dec)

        b = self.baseline
        d = b.x*cHa*cDec + b.y*sHa*cDec + b.z*sDec  # Path difference, m
        return d/self.C*1e6

    def rate(self, RA, Dec, gst, wavelength=None):
        """Return the baseline fringe rate at one GST.

        Args:
            RA, Dec: Source position in radians.
            gst: Greenwich sidereal time in radians.
            wavelength: Observing wavelength in m; the result is then in Hz.
                If omitted, C is used and the delay rate (s/s) is returned.
        """
        ha = self._hourAngle(RA, gst)
        sHa = sin(ha)
        cHa = cos(ha)
        cDec = cos(Dec)

        if wavelength is None:
            wavelength = self.C  # Return delay rate

        # pi/(12*3600) rad/s corresponds to one rotation per 24 h.
        # NOTE(review): a sidereal day (~86164 s) would be slightly more exact.
        rate = (-self.baseline.x*sHa*cDec + self.baseline.y*cHa*cDec) / (12.0*3600.0) * pi / wavelength
        return rate

    def UVtrack(self, RA, Dec, gsts=None, wavelength=None):
        """Compute (u, v) for every GST and store them as numpy arrays in .u and .v.

        Args:
            RA, Dec: Source position in radians.
            gsts: Iterable of GSTs in radians; defaults to .gst from calcUp().
            wavelength: As for uv(); if omitted the values are in km.

        Raises:
            ValueError: if gsts is omitted and calcUp() has not been run.
        """
        if gsts is None:
            gsts = getattr(self, "gst", None)
            if gsts is None:
                raise ValueError("No GSTs given and calcUp() has not been run")

        points = [self.uv(RA, Dec, gst, wavelength) for gst in gsts]
        self.u = np.array([u for u, _ in points])
        self.v = np.array([v for _, v in points])

    def sensitivity(self, integration, bandwidth, dualpol=True):
        """Return the baseline sensitivity in Jy.

        Computed as sqrt(SEFD1*SEFD2) / (0.88*sqrt(npol*bandwidth*integration)),
        with npol = 2 for dual polarisation and 1 otherwise.  Bandwidth and
        integration are presumably in Hz and seconds.
        NOTE(review): 0.88 looks like a system efficiency factor; confirm.

        Returns 0 if either SEFD is negative (used for unknown values).
        """
        polfact = 2 if dualpol else 1

        if self.ant1.sefd < 0 or self.ant2.sefd < 0:
            return 0

        return sqrt(self.ant1.sefd * self.ant2.sefd) / (0.88*sqrt(polfact*bandwidth*integration))

    def isArray(self, array, both=True):
        """Return whether both (default) or either telescope belongs to ``array``."""
        if both:
            return self.ant1.isArray(array) and self.ant2.isArray(array)
        return self.ant1.isArray(array) or self.ant2.isArray(array)

    def isTel(self, tel):
        """Return True if ``tel`` matches either telescope's name or code."""
        return (self.ant1.name == tel or self.ant2.name == tel
                or self.ant1.code == tel or self.ant2.code == tel)
   
# Known telescopes, keyed by telescope name.  Each value is
#   (AntID, Array, ElevationLimit, SEFD, Location)
# where
#   AntID          short antenna code; getTelescope() also matches on it
#   Array          an Array member, a list of Array members, or None
#   ElevationLimit lower elevation limit, in degrees
#   SEFD           system equivalent flux density, in Jy; -1 presumably marks an
#                  unknown value (baseline sensitivity returns 0 for a negative SEFD)
#   Location       [X, Y, Z] geocentric Cartesian position in m, or [Long, Lat] in degrees
# Trailing comments presumably list the receiver bands available at each site (TODO confirm).
telescopes = {
    "ATCA"   :    ('At', Array.LBA, 12,  68,  [-4751639.85972, 2791700.35670, -3200491.11339]), # L, S, C, 6.7, X, K, Q, W
    "Mopra"  :    ('Mp', Array.LBA, 12, 240,  [-4682769.05850, 2802619.04217, -3291759.33837]), # L, S, C, 6.7, X, K, Q, W
    "Parkes" :    ('Pa', Array.LBA, 30,  40,  [-4554232.4864,  2816758.8662,  -3454035.0137]),  # L, S, C, 6.7, X, K
    "Hobart" :    ('Ho', Array.LBA, 16,  470, [-3950237.3590,  2522347.6804,  -4311561.8790]),  # L, S, C, 6.7, X, K
    "Ceduna" :    ('Cd', Array.LBA, 10, 1000, [-3753442.7457,  3912709.7530,  -3348067.6095]),  # L, S, C, 6.7, X, K
    "ASKAP":      ('Ak', Array.LBA, 15.3,2000,[-2556088.333, 5097405.644, -2848428.244]),       # L
    "Effelsberg": ('Ef', Array.EVN, 9,-1,     [4033947.2355, 486990.7943,  4900431.0017]),
    "Medicina":   ('Mc', Array.EVN,5,-1,      [4461369.6718,   919597.1349,  4449559.3995]),
    "Noto":       ('Nt',    Array.EVN,5,-1,   [4934562.8175,  1321201.5528,  3806484.7555]),
    "Jodrell":    ('Jb', Array.EVN, 3,-1,     [3822626.0400, -154105.6500, 5086486.0400]),
    "Westerbork": ('Wb', Array.EVN,15,-1,     [3828445.6590,  445223.6000, 5064921.5680]),
    "Torun":      ("Tr", Array.EVN, 3, -1, [3638558.5100, 1221969.7200, 5077036.7600]),
    "Hartebeesthoek":('Hh', [Array.EVN,Array.LBA] ,15,-1,[5085442.7608,  2668263.8091, -2768696.7292]),
    "Shanghai" :  ('Sh', [Array.EVN,Array.EAVN] ,5,-1,[-2847698.0296,  4659872.5718,  3283958.5330]),
    "Urumqi" :    ('Ur', [Array.EVN,Array.EAVN], 5,-1,[228310.1811,  4631922.9018,  4367064.2110]),
    "DSS43" :     ('Ti', Array.LBA, 6,23,[-4460894.7273,  2682361.5296, -3674748.4238]),
    "DSS63" :     ('Ro', Array.EVN,6,-1,[4849092.6814,  -360180.5350,  4115109.1298]),
    "Sardinia":   ('Sd', Array.EVN,5,-1,[4865182.7660,   791922.6890,  4035137.1740]),
    "Kunming":    ('Km', [Array.EVN,Array.EAVN], 5,-1,[-1281152.8850,  5640864.3903,  2682653.4583]),  # S, 6.7, X
    "Miyun":      ('My', [Array.EVN,Array.EAVN], 7,-1,[-2201304.5880,  4324789.2160,  4125367.9130]),
    "Katherine":  ('Ke', Array.LBA, 5, -1, [-4147354.6913,  4581542.3772, -1573303.1565]), # S, C, X
    "Yarragadee": ('Yg', Array.LBA, 5, -1, [-2388896.1890,  5043350.0019, -3078590.8037]), # S, C, X
    "Warkworth":  ('Wa', Array.LBA, 7,-1,[-5115425.818,   477880.248, -3767042.055]),      # S, C, X
    "SKA":        ('Sk', Array.SKA,  10, -1, [22.13130556, -30.96841667]),
    "Ghana":      ('Gh', None,       15, -1, [-0.305154, 5.750509]),
    "Nobeyama":   ('No', Array.EAVN, 10, -1, [-3871025.4987, 3428107.3984, 3724038.7361]),    # K, Q
    "Takahagi":   ('Ta', Array.EAVN, 10, -1, [-3961882.0160, 3243372.5190, 3790687.4570]),    # K
    "Mizusawa":   ('Vm', Array.EAVN, 10, -1, [-3857244.9089, 3108782.9415, 4003899.1770]),    # K, Q
    "Iriki":      ('Vr', Array.EAVN, 10, -1, [-3521719.8633, 4132174.6847, 3336994.1305]),    # K, Q
    "Ogasawara":  ('Vo', Array.EAVN, 10, -1, [-4491068.4253, 3481545.2331, 2887399.7871]),    # K, Q
    "Ishigaki":   ('Vs', Array.EAVN, 10, -1, [-3263995.2318, 4808056.3788, 2619948.6690]),    # K, Q
    "Hitachi":    ('Hi', Array.EAVN,  5, -1, [-3961789.1650, 3243597.5310, 3790597.7000]),    # 6.7, X (K?)
    "Yamaguchi":  ('Ym', Array.EAVN,  5, -1, [-3502544,	3950966.4,	3566381.2]),              # 6.7, X
    "Usuda":      ('Ud', Array.EAVN,  8, -1, [-3855355,  3427427.6,  3740971.3]),             # L, S, 6.7, X
    "Yonsei":     ('Ky', [Array.EAVN,Array.EVN], 5, -1, [-3042280.9137, 4045902.7164, 3867374.3544]), # K, Q, W
    "Ulsan":      ('Ku', [Array.EAVN,Array.EVN], 5, -1, [-3287268.7200, 4023450.0790, 3687379.9390]), # K, Q, W
    "Tamna":      ('Kt', [Array.EAVN,Array.EVN], 5, -1, [-3171731.7246, 4292678.4575, 3481038.7330]), # K, Q, W
    "Sejong":     ('Sj', [Array.EAVN,Array.EVN], 5, -1, [-3110079.9600, 4082066.7340, 3775076.8320]), # S, X, K, Q
    "Tianma":     ('T6', Array.EAVN, 7, -1, [-2826708.6030, 4679237.0770, 3274667.5510]),
    "Nanshan":    ('Na', Array.EAVN, 10, -1, [ 87.178169, 43.471525]),
    "TNRT":       ('Th', Array.EAVN, 5, -1, [99.216944,  18.864278]),
    "Songkhla":   ('So', Array.EAVN, 10, -1, [100.608347, 7.156891]),
    "GMRT":       ('Gm', None, 15, -1, [ 74.049300, 19.094800]),
    "FAST":       ('Fa', None, 50, -1, [ 106.856580, 25.652920]),
    "Indonesia":  ('In', None, 15, -1, [107.411357, -6.522294]), # Not sure what this is
    "JATILU32":   ('Jh', Array.SEAVN,  5, -1, [ -1896321.6268,  6046883.2955, -719777.6803]),
  	"TIMAU32":    ('Tu', Array.SEAVN,  5, -1, [ -3513232.1112,  5219086.2946, -1052657.7758]),
    "Malaysia":   ('Ml', Array.SEAVN,  5, -1, [ -1290006.7579,  6236667.7246,  347159.2344]),
    "BR":         ('Br', [Array.VLBA], 3, -1, [ -2112065.2062, -3705356.5048, 4726813.6759]),
    "FD":         ('Fd', [Array.VLBA], 3, -1, [ -1324009.3266, -5332181.9547, 3231962.3949]),
    "GBT":        ('Gb', [Array.VLBA], 6, -1, [   882589.4102, -4924872.3416, 3943729.4062]),
    "HN":         ('Hn', [Array.VLBA], 3, -1, [  1446374.8658, -4447939.6774, 4322306.1838]),
    "KP":         ('Kp', [Array.VLBA], 3, -1, [ -1995678.8402, -5037317.6968, 3357328.0251]),
    "LA":         ('La', [Array.VLBA], 3, -1, [ -1449752.5839, -4975298.5757, 3709123.8459]),
    "MK":         ('Mk', [Array.VLBA], 3, -1, [ -5464075.1847, -2495248.1055, 2148297.3649]),
    "NL":         ('Nl', [Array.VLBA], 3, -1, [  -130872.4987, -4762317.0925, 4226851.0014]),
    "OV":         ('Ov', [Array.VLBA], 3, -1, [ -2409150.4018, -4478573.1180, 3838617.3385]),
    "SC":         ('Sc', [Array.VLBA], 3, -1, [  2607848.6379, -5488069.5358, 1932739.7326]),
}
   
def getTelescope(tel):
    """Return a Telescope built from the ``telescopes`` table, or None.

    ``tel`` is looked up first as a table key (telescope name) and then as an
    antenna code (first element of each entry); the first code match in table
    order is used.  A single (non-sequence) Array value is wrapped in a list,
    so the Telescope always receives a sequence.  Prints a message and returns
    None when nothing matches.
    """
    if tel in telescopes:
        found = (tel, telescopes[tel])
    else:
        found = next(((name, entry) for name, entry in telescopes.items()
                      if entry[0] == tel), None)

    if found is None:
        print("Could not find {}".format(tel))
        return None

    name, entry = found
    arrays = entry[1]
    if not isinstance(arrays, (list, tuple, np.ndarray)):  # Scalar (or None)
        arrays = [arrays]
    return Telescope(name, entry[0], arrays, entry[2], entry[3], location=entry[4])

def telescopeList():
    """Return the names of all known telescopes, in table order."""
    return [name for name in telescopes]

# Return a list of baselines from a passed list of telescopes
def createBaselines(antennas, refAnts=None):
    """Return a list of Baseline objects for a list of telescopes.

    Without refAnts, every unique pair (i < j) of ``antennas`` is returned in
    index order.  With refAnts, every (reference, antenna) pair is returned,
    iterating over the reference antennas in the outer loop.
    """
    if refAnts is not None:
        return [Baseline(ref, ant) for ref in refAnts for ant in antennas]

    count = len(antennas)
    return [Baseline(antennas[i], antennas[j])
            for i in range(count)
            for j in range(i + 1, count)]

def imageRMS(baselines, bandwidth=None, integration=None, dualpol=True, weight=2):
    """Estimate the image-plane RMS noise, in Jy, from a set of baselines.

    Each baseline's sensitivity sigma (from its ``sensitivity`` method) gets
    a weight w = sigma**-weight.  weight=2 gives inverse-variance weighting
    and weight=0 equal weights.  The combined noise is
    sqrt(sum(w**2 * sigma**2)) / sum(w).

    Baselines whose sensitivity is 0 are skipped.  Baseline.sensitivity
    returns 0 when an SEFD is unknown (negative); such a baseline carries no
    noise information, and 0**-weight would raise ZeroDivisionError.

    Args:
        baselines: Iterable of objects providing
            ``sensitivity(integration, bandwidth, dualpol)``.
        bandwidth, integration: Passed to ``sensitivity``; both are required
            in practice despite their None defaults.
        dualpol: Passed to ``sensitivity``.
        weight: Weighting exponent (see above).

    Raises:
        ValueError: if no baseline has a non-zero sensitivity.
    """
    exponent = -weight
    total_weight = 0
    weighted_var = 0

    for b in baselines:
        sigma = b.sensitivity(integration, bandwidth, dualpol)
        if sigma <= 0:  # Unknown SEFD: no usable noise estimate
            continue
        w = sigma ** exponent
        total_weight += w
        weighted_var += w*w*sigma*sigma

    if total_weight == 0:
        raise ValueError("No baselines with a known sensitivity")

    return sqrt(weighted_var)/total_weight

