#!/usr/bin/env python3
# -*- coding: utf-8 -*-
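"""
Plot a map of the antenna stations listed in position.cat; each station's
marker size grows with how often its two-letter code appears in the array
string given with -a.

Created on Fri Feb 17 11:54:35 2017

@author: jimlovell
"""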

import numpy as np
from matplotlib import pyplot as plt
from mpl_toolkits.basemap import Basemap
from pyproj import Proj, transform
import sys, getopt, math

# Display names for two-letter station codes, used when labelling the map.
AntNames = dict(
    At='ATCA',
    Ho='Hobart',
    Mp='Mopra',
    Cd='Ceduna',
    Pa='Parkes',
    Ak='ASKAP',
    Hh='Hartebeesthoek',
    Yg='Yarragadee',
    Ke='Katherine',
    Ti='Tidbinbilla',
    Ww='Warkworth',
)

def count_stations(name,D):
    file = open(name, 'r')
    while True:
        line = file.readline()
        if not line: break
        line = line.rstrip()
        if str(line) in D:
#            print('increment {}'.format(str(line)))
            D[str(line)] += 1
        else:
#            print('New {}'.format(str(line)))
            D[str(line)] = 1
        
    file.close()
#    print('D = {}'.format(D))
    return(D)
    
def read_station_positions(name):
    """Read station XYZ positions from a catalogue and convert them to lon/lat.

    Lines starting with ``*`` (comments) and blank lines are ignored. Every
    other line must have at least nine whitespace-separated fields: the first
    is the station code and the third to fifth are geocentric X, Y, Z. These
    are converted with pyproj from a WGS84 ``geocent`` projection to WGS84
    ``latlong`` (degrees).

    Args:
        name: path of the station catalogue (e.g. ``position.cat``).

    Returns:
        (dlon, dlat): dicts mapping station code -> longitude / latitude.

    Raises:
        ValueError: if a data line has fewer than nine fields.
    """
    dlat = {}
    dlon = {}
    ecef = Proj(proj='geocent', ellps='WGS84', datum='WGS84')
    lla = Proj(proj='latlong', ellps='WGS84', datum='WGS84')

    with open(name, 'r') as cat_file:
        for lineno, raw in enumerate(cat_file, start=1):
            line = raw.strip()
            if not line or line.startswith("*"):
                continue
            fields = line.split()
            if len(fields) < 9:
                raise ValueError("{}: line {}: expected at least 9 fields, got {}"
                                 .format(name, lineno, len(fields)))
            st = fields[0]
            # pyproj needs numbers, not the raw text fields.
            x, y, z = (float(v) for v in fields[2:5])
            lon2, lat2, _alt = transform(ecef, lla, x, y, z)
            dlat[st] = float(lat2)
            dlon[st] = float(lon2)
    return (dlon, dlat)

    
def count_stations_from_cat_list(name, counts, st_string):
    """Count how often each catalogued station code occurs in ``st_string``.

    Lines starting with ``*`` (comments) and blank lines are ignored. Every
    other line must have at least nine whitespace-separated fields; the first
    is the station code. Each code is counted with ``str.count``, i.e. as
    non-overlapping substring occurrences in ``st_string``. Catalogued
    stations that do not appear get a count of 0.

    Args:
        name: path of the station catalogue (e.g. ``position.cat``).
        counts: dict updated in place with code -> occurrence count.
        st_string: concatenated two-letter station codes.

    Returns:
        The same dict ``counts``.

    Raises:
        ValueError: if a data line has fewer than nine fields.
    """
    with open(name, 'r') as cat_file:
        for lineno, raw in enumerate(cat_file, start=1):
            line = raw.strip()
            if not line or line.startswith("*"):
                continue
            fields = line.split()
            if len(fields) < 9:
                raise ValueError("{}: line {}: expected at least 9 fields, got {}"
                                 .format(name, lineno, len(fields)))
            code = fields[0]
            counts[code] = st_string.count(code)
    return counts

    
def read_master_schedule(name):
    """Concatenate the observing-station codes of every session in a schedule file.

    Only lines that start with ``|`` (after stripping) are parsed. Each is
    split on ``|``. The eighth field (index 7) holds the station list, and
    only its first whitespace-separated token is used. The abbreviation
    ``Va`` is expanded to ``BrFdHnKpLaMkNlPtOvSc``. Rows whose station field
    is empty are skipped.

    Args:
        name: path of the master-schedule file.

    Returns:
        str: the station tokens of all sessions, concatenated in file order.

    Raises:
        ValueError: if a ``|`` row has fewer than 16 fields.
    """
    # Va is the abbreviation for the VLBA; expand it to its member stations.
    vlba_codes = "BrFdHnKpLaMkNlPtOvSc"
    pieces = []
    with open(name, 'r') as sched_file:
        for lineno, raw in enumerate(sched_file, start=1):
            line = raw.strip()
            if not line.startswith("|"):
                continue
            fields = line.split("|")
            if len(fields) < 16:
                raise ValueError("{}: line {}: expected at least 16 '|' fields, got {}"
                                 .format(name, lineno, len(fields)))
            tokens = fields[7].split()
            if not tokens:
                # No stations listed for this session.
                continue
            pieces.append(tokens[0].replace("Va", vlba_codes))
    # join avoids quadratic string concatenation.
    return "".join(pieces)

def mappit(m, station_counts, lat, lon):
    """Draw the base map and a marker for every station with a positive count.

    Args:
        m: Basemap-like object. It must provide ``drawmapboundary``,
            ``fillcontinents``, ``drawmeridians``, ``drawparallels`` and
            ``plot``, and be callable as ``m(lon, lat) -> (x, y)``.
        station_counts: dict station code -> occurrence count. Stations whose
            count is not > 0 are not plotted.
        lat, lon: dicts station code -> latitude / longitude in degrees.
            Longitudes above 180 have 360 subtracted before projection.

    Marker size is ``(count - 1) / 12 + 4``. ASKAP ('Ak') is drawn in its own
    colour.
    """
    water_colour = '#d5e0f2'
    land_colour = '#9ac195'
    station_colour = '#6234f9'
    askap_colour = '#ff6666'

    m.drawmapboundary(fill_color=water_colour)
    m.fillcontinents(color=land_colour, lake_color=water_colour)
    m.drawmeridians(np.arange(0, 360, 30), labels=[0, 0, 0, 1], fontsize=6, color='grey')
    m.drawparallels(np.arange(-90, 90, 15), labels=[1, 0, 0, 0], fontsize=6, color='grey')

    for k in sorted(station_counts):
        count = station_counts[k]
        if count > 0:
            mylon = lon[k]
            if mylon > 180.0:
                mylon -= 360.0
            x, y = m(mylon, lat[k])
            size = ((count - 1.0) / 12) + 4
            colour = askap_colour if k == 'Ak' else station_colour
            m.plot(x, y, marker='o', color=colour, markersize=size)

def draw_baselines(mapname, lat, lon, stations=None, just_to_first=False,
                   linewidth=2, color='k', linestyle='-'):
    """Draw straight baseline segments between stations on a Basemap.

    Args:
        mapname: Basemap-like object. It is called as ``mapname(lons, lats)``
            to project each segment's endpoints, and its ``plot`` method draws
            the segment.
        lat, lon: dicts mapping station code -> latitude / longitude in
            degrees. Longitudes above 180 have 360 subtracted before plotting.
        stations: station codes to join; defaults to ``['Hb', 'Ke']``. The
            caller's list is not modified.
        just_to_first: if True, draw only the baselines from the first
            station to each of the others. Otherwise draw every pair.
        linewidth, color, linestyle: passed through to ``mapname.plot``.
    """
    if stations is None:
        stations = ['Hb', 'Ke']
    # Work on a copy. Popping/reversing the argument in place used to destroy
    # the caller's list and the shared default list between calls.
    remaining = list(stations)

    if just_to_first:
        if not remaining:
            return
        reference = remaining[0]
        # Reversed to keep the historical drawing (z-) order.
        pairs = [(reference, k) for k in reversed(remaining[1:])]
    else:
        # The last station is joined to all earlier ones, then dropped; repeat.
        pairs = []
        while len(remaining) > 1:
            reference = remaining.pop()
            pairs.extend([(reference, k) for k in remaining])

    def _wrap(longitude):
        return longitude - 360.0 if longitude > 180.0 else longitude

    for reference, other in pairs:
        lats = [lat[reference], lat[other]]
        lons = [_wrap(lon[reference]), _wrap(lon[other])]
        print("inputs to line plots\nlats: {}\nlons: {}\n".format(lats,lons))
        x, y = mapname(lons, lats)
        print("inputs to line plots\nx: {}\ny: {}\n".format(x,y))
        mapname.plot(x, y, linestyle, markersize=12, linewidth=linewidth,
                     color=color, markerfacecolor='b')

    
def _parse_args(argv, usage):
    """Parse command-line options into a settings dict.

    Prints ``usage`` and exits with status 0 for ``-h``. Exits with status 2
    for an unknown option or a non-numeric ``--dpi``/``--fontsize`` value.
    """
    settings = {
        'st_string': 'HoHhKeYgCdAkMpTiPaAtWw',
        'output_file': "lba_map.png",
        'dpi': 200,
        'fsize': 6,
        'onscreen': False,
    }
    try:
        opts, _ = getopt.getopt(argv, "hxa:o:d:",
                                ["array=", "ofile=", "dpi=", "fontsize="])
    except getopt.GetoptError:
        print(usage)
        sys.exit(2)

    for opt, arg in opts:
        if opt == '-h':
            print(usage)
            sys.exit()
        try:
            if opt == '-x':
                settings['onscreen'] = True
            elif opt in ("-a", "--array"):
                settings['st_string'] = arg
            elif opt in ("-o", "--ofile"):
                settings['output_file'] = arg
            elif opt in ("-d", "--dpi"):
                settings['dpi'] = float(arg)
            elif opt == "--fontsize":
                # Previously `opt in ("--fontsize")`: a substring test on a
                # plain string, not tuple membership.
                settings['fsize'] = float(arg)
        except ValueError:
            print("Invalid numeric value for {}: {!r}".format(opt, arg))
            print(usage)
            sys.exit(2)
    return settings


def _label_stations(m, station_codes, lat, lon, fsize):
    """Label each station on map ``m`` with its display name (or its code).

    Stations missing from ``lat``/``lon`` are reported on stderr and skipped,
    rather than aborting the whole plot with a KeyError.
    """
    for code in station_codes:
        if code not in lat or code not in lon:
            print("Warning: station {} not found in position.cat; not labelled"
                  .format(code), file=sys.stderr)
            continue
        station_lon = lon[code]
        # Wrap longitudes above 180 the same way the station markers are.
        if station_lon > 180.0:
            station_lon -= 360.0
        x, y = m(station_lon, lat[code])
        label = AntNames.get(code, code)
        if code in ('Hh', 'Mp', 'Ak'):
            # These labels go to the right of the marker; all others to the left.
            plt.text(x + 150000, y, ' ' + label, fontsize=fsize,
                     horizontalalignment='left')
        else:
            plt.text(x - 50000, y, label + ' ', fontsize=fsize,
                     horizontalalignment='right')


def main(argv):
    """Command-line entry point: draw the station map and save or show it.

    Options:
        -a/--array STRING  concatenated two-letter station codes to plot
        -o/--ofile FILE    output image (default lba_map.png)
        -d/--dpi N         output resolution (default 200)
        --fontsize N       label font size (default 6)
        -x                 show the map on screen instead of saving it
        -h                 print usage and exit

    Station positions are read from ``position.cat`` in the working directory.
    """
    usage = "lba_map.py -a <array_string> -o <outfile> -d <dpi> [--fontsize <size>] [-x]"
    settings = _parse_args(argv, usage)
    st_string = settings['st_string']

    # Occurrence count of every catalogued station in the requested array string.
    st_counts = count_stations_from_cat_list('position.cat', {}, st_string)
    (lon, lat) = read_station_positions('position.cat')

    fig = plt.figure(figsize=((15.0 / 2.0), (9.0 / 2.0)))

    # Australia, NZ and South Africa
    m = Basemap(projection='merc', llcrnrlat=-55, urcrnrlat=10,
                llcrnrlon=10, urcrnrlon=185, lat_ts=20, resolution='c')
    # Alternative extents (same projection/lat_ts/resolution):
    #   Australia only: llcrnrlat=-45, urcrnrlat=-8, llcrnrlon=105, urcrnrlon=158
    #   Australia + NZ: llcrnrlat=-50, urcrnrlat=5,  llcrnrlon=90,  urcrnrlon=185

    mappit(m, st_counts, lat, lon)

    # Split the array string into its two-letter station codes.
    station_codes = [st_string[i:i + 2] for i in range(0, len(st_string), 2)]
    _label_stations(m, station_codes, lat, lon, settings['fsize'])

    # Baselines can be overlaid here, e.g.
    #   draw_baselines(mapname=m, lat=lat, lon=lon, stations=['Hb', 'Ke'], color='b')

    if settings['onscreen']:
        plt.show()
    else:
        fig.savefig(settings['output_file'], dpi=settings['dpi'])
    
if __name__ == '__main__':
    # Hand everything after the program name to main().
    cli_args = sys.argv[1:]
    main(cli_args)
