#!/usr/bin/python

#core imports
from __future__ import print_function
from __future__ import division
from builtins import str
from builtins import range
from past.utils import old_div
import argparse
import sys
import os
import math
import re

#non-core imports
#set the plotting back-end to 'agg' to avoid display
import numpy as np

import matplotlib
# Select the non-interactive 'Agg' back-end before pyplot is imported so that
# no display is needed.  Older matplotlib releases accept warn=False here;
# newer releases no longer accept the 'warn' keyword and raise TypeError, in
# which case the back-end is selected without it.
try:
    matplotlib.use("Agg", warn=False)
except TypeError:
    matplotlib.use("Agg")

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
#matplotlib.rcParams.update({'savefig.dpi':300})
import matplotlib.colors as mcolors
import matplotlib.cm as cmx
import matplotlib.colorbar as mcbar
import pylab


#HOPS module imports
import vpal.processing
import vpal.utility
import vpal.fringe_file_manipulation

import mk4b


################################################################################

def _build_parser():
    """Build the command line parser for az_el_amp_plot.py.

    Returns:
        argparse.ArgumentParser: parser with four positional arguments
        (control file, reference station, stations, experiment directory)
        followed by the optional processing and data-cut settings.
    """
    parser = argparse.ArgumentParser(
        prog='az_el_amp_plot.py',
        description='''utility for plotting az-el dependence of amplitude'''
    )

    parser.add_argument('control_file', help='the control file to be applied to all scans')
    parser.add_argument('ref_station', help='single character code for the reference station')
    parser.add_argument('stations', help='concatenated string of single character codes for all stations to be fringe fit')
    parser.add_argument('experiment_directory', help='relative path to directory containing experiment data')

    parser.add_argument('-n', '--numproc', type=int, dest='num_proc', help='number of concurrent fourfit jobs to run, default=16', default=16)
    # NOTE(review): --channels and --dtec-threshold are accepted but are not
    # used anywhere in this script.
    parser.add_argument('-c', '--channels', dest='channels', help='specify the channels to be used, default=abcdefghijklmnopqrstuvwxyzABCDEF.', default='abcdefghijklmnopqrstuvwxyzABCDEF')
    parser.add_argument('-s', '--snr-min', type=float, dest='snr_min', help='set minimum allowed snr threshold, default=30.', default=30)
    parser.add_argument('-d', '--dtec-threshold', type=float, dest='dtec_thresh', help='set maximum allowed difference in dTEC, default=1.', default=1.0)
    parser.add_argument('-q', '--quality-limit', type=int, dest='quality_lower_limit', help='set the lower limit on fringe quality (inclusive), default=3.', default=3)
    # NOTE(review): store_true combined with default=True means the progress
    # ticker is always enabled, whether or not -p is given.
    parser.add_argument('-p', '--progress', action='store_true', dest='use_progress_ticker', help='monitor process with progress indicator', default=True)
    parser.add_argument('-b', '--begin-scan', dest='begin_scan_limit', help='limit the earliest scan to be used e.g 244-1719', default="000-0000")
    parser.add_argument('-e', '--end-scan', dest='end_scan_limit', help='limit the latest scan to be used, e.g. 244-2345', default="999-9999")
    return parser


def _read_ref_station_az_el(ff, ref_station):
    """Read the reference station azimuth and elevation for one fringe file.

    The values come from the station data file located in the same directory
    as the fringe file and named '<ref_station>..<suffix>', where <suffix> is
    the last six characters of the fringe file name.

    Args:
        ff: fringe file object providing a 'filename' attribute.
        ref_station: single character code of the reference station.

    Returns:
        tuple: (azimuth, elevation), the first entries of the azimuth and
        elevation arrays in the first t303 record of the first model entry.
    """
    root_suffix = ff.filename[-6:]
    station_data_file = os.path.join(os.path.dirname(ff.filename), ref_station + ".." + root_suffix)
    sdata = mk4b.mk4sdata(station_data_file)
    t303 = sdata.model[0].t303[0].contents
    return t303.azimuth[0], t303.elevation[0]


def _plot_az_el_amp(az, el, amps, ref_station, exp_name, plot_name):
    """Scatter-plot elevation against azimuth, coloured by log10(amplitude).

    Args:
        az: numpy array of azimuths (x axis).
        el: numpy array of elevations (y axis).
        amps: numpy array of strictly positive amplitudes, same length as
            az and el; must not be empty.
        ref_station: station code used in the plot title.
        exp_name: experiment name used in the plot title.
        plot_name: output path without extension; '.png' is appended.
    """
    fig_width_pt = 600  # Get this from LaTeX using \showthe\columnwidth
    inches_per_pt = 1.0/72.27               # Convert pt to inch
    height_to_width = 0.7
    fig_width = fig_width_pt*inches_per_pt  # width in inches
    fig_height = fig_width*height_to_width  # height in inches

    # rcParams must be set before the figure is created for figsize to apply
    matplotlib.rcParams.update({'savefig.dpi':350,
                                'text.usetex':True,
                                'figure.figsize':[fig_width, fig_height],
                                'font.family':"serif",
                                'font.serif':["Times"]})

    plt.figure()
    ax = plt.subplot(111)

    # compute the colour values once and share one normalization between the
    # scatter points and the colorbar so both always use the same scale
    log_amps = np.log10(amps)
    cmap = plt.get_cmap('viridis')
    normalize = mcolors.Normalize(vmin=log_amps.min(), vmax=log_amps.max())
    plt.scatter(az, el, s=4, c=log_amps, cmap=cmap, norm=normalize, marker='o')

    x_min = np.min(az)
    x_max = np.max(az)

    plt.grid(True, which='both', linestyle=':', alpha=0.6)
    plt.xticks(fontsize=8)
    plt.yticks(fontsize=8)
    plt.xlabel('scan azimuth [deg]', fontsize=10)
    plt.ylabel('scan elevation [deg]', fontsize=10)
    plt.title('Amplitude dependence on azimuth/elevation for station '+ref_station+' in experiment '+exp_name,fontsize=10)

    plt.ylim(-5, np.max(el)+10)
    plt.xlim(x_min-15,x_max+15)

    cax, _ = mcbar.make_axes(ax, location='right', anchor=(0.5, -0.2), aspect=30 )
    cbar = mcbar.ColorbarBase(cax, cmap=cmap, norm=normalize, orientation='vertical')
    # the colour axis shows the base-10 logarithm of the amplitude
    cbar.ax.set_xlabel(r'$\log_{10}$(amplitude)')

    plt.savefig(plot_name + '.png', bbox_inches='tight')
    plt.close()


def main():
    """Fringe fit all reference-station baselines and plot amplitude vs az/el.

    Steps:
      (1) parse the command line and verify that the control file exists
      (2) determine the baselines that involve the reference station
      (3) for each baseline, load/compute the fringe files for the 'I'
          pol-product and apply the quality code and snr cuts
      (4) for each surviving fringe file, record its amplitude together with
          the reference station azimuth and elevation
      (5) write the scatter plot to
          './az_el_amp_plot_<ref_station>_<experiment>.png'

    Exits with status 1 if the control file cannot be found or if no usable
    data remain after the cuts.
    """
    args = _build_parser().parse_args()

    control_file = args.control_file
    ref_station = args.ref_station
    stations = args.stations
    polprod = ['I'] #['XX','XY','YX','YY']

    abs_exp_dir = os.path.abspath(args.experiment_directory)
    abs_control_file = os.path.abspath(control_file)
    exp_name = os.path.split(abs_exp_dir)[1]

    # fail early, before any expensive processing is started
    if not os.path.isfile(abs_control_file):
        print("could not find control file: ", control_file)
        sys.exit(1)

    print('Calculating baselines')

    #determine all possible baselines
    baseline_list = vpal.processing.construct_valid_baseline_list(abs_exp_dir, ref_station, stations, network_reference_baselines_only=True)

    print('Baselines:', baseline_list)

    # accepted quality codes: the lower limit (inclusive) up to 9
    qcode_list = [str(q) for q in range(args.quality_lower_limit, 10)]

    #default output filename
    plot_name = "./az_el_amp_plot_" + ref_station + "_" + exp_name

    amps = []
    az = []
    el = []

    # bline is a two-element string of [reference_station, other_station]
    for bline in baseline_list:

        #collect/compute fringe files
        set_commands = "set gen_cf_record true"
        ff_list = vpal.processing.load_and_batch_fourfit(
            abs_exp_dir, bline[0], bline[1], abs_control_file, set_commands,
            num_processes=args.num_proc, start_scan_limit=args.begin_scan_limit,
            stop_scan_limit=args.end_scan_limit, pol_products=polprod, use_progress_ticker=args.use_progress_ticker
        )

        print('For baseline '+bline+' there are '+str(len(ff_list))+' fringe files')

        #apply cuts (the upper snr bound of 500 is hard-coded)
        filter_list = [
            vpal.utility.DiscreteQuantityFilter("quality", qcode_list),
            vpal.utility.ContinuousQuantityFilter("snr", args.snr_min, 500),
        ]
        vpal.utility.combined_filter(ff_list, filter_list)

        if len(ff_list) == 0:
            print("Error: no fringe files available after cuts, skipping baseline: ", bline)
            continue

        for ff in ff_list:
            ff_az, ff_el = _read_ref_station_az_el(ff, ref_station)
            amps.append(ff.amp)
            az.append(ff_az)
            el.append(ff_el)

    az = np.asarray(az, dtype=float)
    el = np.asarray(el, dtype=float)
    amps = np.asarray(amps, dtype=float)

    # log10 is only defined for positive amplitudes; anything else (including
    # NaN) would produce non-finite colour limits, so it is dropped here
    usable = amps > 0
    num_dropped = int(amps.size - np.count_nonzero(usable))
    if num_dropped > 0:
        print("Warning: ignoring", num_dropped, "fringe files with non-positive amplitude")
        az = az[usable]
        el = el[usable]
        amps = amps[usable]

    if amps.size == 0:
        print("Error: no usable fringe data on any baseline, no plot was made")
        sys.exit(1)

    _plot_az_el_amp(az, el, amps, ref_station, exp_name, plot_name)
    



if __name__ == "__main__":
    # script entry point: run the tool, then report success to the shell
    main()
    sys.exit(0)
