#!@PYTHON@

#core imports
from __future__ import print_function
from __future__ import division
from builtins import str
from builtins import range
from past.utils import old_div
import argparse
import sys
import os
import math
import re
from datetime import datetime

#non-core imports
#set the plotting back-end to 'agg' to avoid display
import matplotlib as mpl
mpl.use('Agg')
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.colorbar as mcbar
import matplotlib.cm as cmx
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

#HOPS module imports
import vpal.processing
import vpal.utility
import vpal.fringe_file_manipulation


################################################################################

def _parse_arguments():
    """Build the command line parser and return the parsed arguments.

    Example invocation:
        phase_resid.py ./cf_GHEVY_ff GHEV I ./
    """
    parser = argparse.ArgumentParser(
        prog='phase_resid.py',
        description='''utility for generating baseline time trend plots of channel-by-channel phase residuals'''
    )

    parser.add_argument('control_file', help='the control file to be applied to all scans')
    parser.add_argument('stations', help='concatenated string of single character codes for all stations to be fringe fit')
    parser.add_argument('pol_product', help='the polarization-product to be fringe fit')
    parser.add_argument('experiment_directory', help='relative path to directory containing experiment data')

    parser.add_argument('-n', '--numproc', type=int, dest='num_proc', help='number of concurrent fourfit jobs to run, default=1', default=1)
    parser.add_argument('-c', '--channels', dest='channels', help='specify the channels to be used, default=abcdefghijklmnopqrstuvwxyzABCDEF.', default='abcdefghijklmnopqrstuvwxyzABCDEF')
    parser.add_argument('-s', '--snr-min', type=float, dest='snr_min', help='set minimum allowed snr threshold, default=30.', default=30)
    parser.add_argument('-d', '--dtec-threshold', type=float, dest='dtec_thresh', help='set maximum allowed difference in dTEC, default=1.', default=1.0)
    parser.add_argument('-q', '--quality-limit', type=int, dest='quality_lower_limit', help='set the lower limit on fringe quality (inclusive), default=3.', default=3)
    parser.add_argument('-p', '--progress', action='store_true', dest='use_progress_ticker', help='monitor process with progress indicator', default=False)
    parser.add_argument('-b', '--begin-scan', dest='begin_scan_limit', help='limit the earliest scan to be used e.g 244-1719', default="000-0000")
    parser.add_argument('-e', '--end-scan', dest='end_scan_limit', help='limit the latest scan to be used, e.g. 244-2345', default="999-9999")
    #type=float lets argparse reject a non-numeric value with a usage message
    parser.add_argument('-r', '--remove-outliers', type=float, dest='remove_outlier_nsigma', help='remove scans which are n*sigma away from the mean, default=0 (off)', default=0.0 )
    parser.add_argument('-z', '--z-axis', dest='z_axis', help='select color axis, options are: amp, dtec, or none, default=none', default='none' )

    return parser.parse_args()


def _collect_phase_residuals(ff_list):
    """Extract and normalize the channel phase residuals of each fringe file.

    For every entry of ff_list a PhaseResidualData object is filled from the
    entry's filename. For the valid ones, the per-channel phases (degrees) are
    negated, unwrapped, reduced by their circular mean, and limited to the
    range -180..180 before being stored back on the object.

    Returns the list of valid PhaseResidualData objects.
    """
    phase_residuals = []
    for ff in ff_list:
        phresid = vpal.fringe_file_manipulation.PhaseResidualData()
        phresid.extract(ff.filename)
        if phresid.is_valid is not True:
            continue

        channel_ids = list(phresid.phase_residuals.keys())
        raw_deg = np.asarray([phresid.phase_residuals[ch] for ch in channel_ids])

        #np.unwrap needs radians, so negate and convert, unwrap, then convert back to degrees
        unwrapped_rad = np.unwrap( (-1.0*(math.pi/180.0))*raw_deg )
        unwrapped_deg = (180.0/math.pi)*unwrapped_rad

        #remove the circular mean over the channels of this scan
        mean_phase = scipy.stats.circmean(np.asarray(unwrapped_deg), high=180.0, low=-180.0)
        centered_deg = unwrapped_deg - mean_phase

        for ch, phase in zip(channel_ids, centered_deg):
            phresid.phase_residuals[ch] = vpal.utility.limit_periodic_quantity_to_range(phase, low_value=-180.0, high_value=180.0)
        phase_residuals.append(phresid)

    return phase_residuals


def _scan_time_hours(scan_name, first_day):
    """Convert a scan name of the form DDD-HHMM into hours since first_day 00:00."""
    fields = scan_name.split('-')
    day = int(fields[0])
    hour_min = fields[1]
    return (day - first_day)*24.0 + float(hour_min[0:2]) + float(hour_min[2:4])/60.0


def _collect_channel_residuals(phase_residuals, channels, first_day, z_axis, rnsigma):
    """Build the per-channel (time, phase, z) series used for plotting.

    For each channel the circular mean and standard deviation over all scans
    are computed (and printed); the mean is subtracted from every point. When
    rnsigma is non-zero, points whose deviation from the mean is not smaller
    than rnsigma standard deviations are dropped.

    Returns (channel_residuals, z_list): a dict mapping each channel to its
    list of (time, phase, z) tuples, and the flat list of all retained z values.
    """
    channel_residuals = {}
    z_list = []

    for ch in channels:
        channel_residuals[ch] = []
        scans_with_channel = [phr for phr in phase_residuals if ch in phr.phase_residuals]

        #report any phase which escaped the -180..180 limit
        for phr in scans_with_channel:
            if abs(phr.phase_residuals[ch]) > 180:
                print("Scan: ", phr.scan_name, "has channel phase of: ", phr.phase_residuals[ch] )

        if not scans_with_channel:
            continue

        phases = np.asarray([phr.phase_residuals[ch] for phr in scans_with_channel])
        mean_phase = scipy.stats.circmean(phases, high=180.0, low=-180.0) #circular mean phase
        stddev = scipy.stats.circstd(phases, high=180, low=-180) #circular std. dev.
        print( "(mean, std. dev) phase for channel: ", ch, " = ", mean_phase, stddev)

        for phr in scans_with_channel:
            hours = _scan_time_hours(phr.scan_name, first_day)

            #color axis data
            z = 0.0
            if z_axis == 'dtec':
                z = phr.dtec
            elif z_axis == 'amp':
                z = phr.channel_phasor_amplitudes[ch]

            residual = phr.phase_residuals[ch] - mean_phase
            #rnsigma == 0 disables the outlier cut
            if rnsigma == 0.0 or abs(residual) < rnsigma*stddev:
                channel_residuals[ch].append( (hours, residual, z) )
                z_list.append(z)

    return channel_residuals, z_list


def _tick_spacing(ph_range):
    """Return (major, minor) y-tick spacing in degrees for a given phase extent."""
    if ph_range > 135:
        return 90, 30
    if ph_range > 90:
        return 45, 15
    if ph_range > 45:
        return 30, 10
    if ph_range >= 15:
        return 15, 5
    return 5, 1


def _plot_channel_residuals(channel_residuals, z_list, channels, first_day, z_axis, plot_title, plot_name):
    """Draw one phase-vs-time panel per populated channel and save it as <plot_name>.png.

    z_list must not be empty: its min/max define the color normalization.
    """
    use_color_axis = z_axis in ('dtec', 'amp')

    auto_fig = plt.figure(figsize=(8.5,11))
    ax = auto_fig.add_subplot(111)    # The big subplot

    # Turn off axis lines and ticks of the big subplot
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('none')
    #booleans are required here, the string 'off' is truthy and leaves the ticks on
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

    # Set common labels
    ax.set_title(plot_title)
    ax.set_ylabel('Phase (degrees)')

    #plt.get_cmap is used since matplotlib.cm.get_cmap is gone from recent matplotlib
    maps = plt.colormaps()
    cmap = plt.get_cmap('viridis' if 'viridis' in maps else maps[0])
    normalize = mcolors.Normalize(vmin=min(z_list), vmax=max(z_list))

    n_rows = int( math.ceil( len(channels)/2.0 ) + 1 )
    x_label = 'Time (hours) since ' + str(first_day).zfill(3) + "-00:00"

    count = 0
    for ch in channels:
        points = channel_residuals[ch]
        if not points:
            continue

        count += 1
        tmpax = auto_fig.add_subplot(n_rows, 2, count)
        tmpax.text(-0.1, 0.5, ch, fontsize=10, ha='center', transform=tmpax.transAxes)

        x_val = [p[0] for p in points]
        y_val = [p[1] for p in points]
        z_val = [p[2] for p in points]

        #tick spacing depends on the largest absolute phase in this channel
        ph_range = max( abs( max(y_val) ), abs( min(y_val) ) )
        major_step, minor_step = _tick_spacing(ph_range)

        tmpax.yaxis.set_major_locator(MultipleLocator(major_step))
        tmpax.yaxis.set_major_formatter(FormatStrFormatter('%d'))
        # for the minor ticks, use no labels; default NullFormatter
        tmpax.yaxis.set_minor_locator(MultipleLocator(minor_step))
        tmpax.set_xlabel(x_label)
        tmpax.tick_params(axis='both', which='both', direction='in', labelsize=6)

        if use_color_axis:
            colors = [cmap(normalize(value)) for value in z_val]
            tmpax.scatter(x_val, y_val, s=1, c=colors)
        else:
            tmpax.scatter(x_val, y_val, s=1, c='r')

    auto_fig.subplots_adjust(hspace=0.1)

    if use_color_axis:
        cax, _ = mcbar.make_axes(ax, location='bottom', anchor=(0.5, -0.2), aspect=30 )
        cbar = mcbar.ColorbarBase(cax, cmap=cmap, norm=normalize, orientation='horizontal')
        cbar.ax.set_xlabel(z_axis)

    auto_fig_name = plot_name + ".png"
    print("generating plot: ", auto_fig_name)
    auto_fig.savefig(auto_fig_name, bbox_inches='tight')
    plt.close(auto_fig)


def main():
    """Generate per-baseline plots of channel phase residuals versus time.

    Usage: phase_resid.py [options] <control-file> <stations> <pol-product> <experiment-directory>
    e.g.:  phase_resid.py ./cf_GHEVY_ff GHEV I ./

    For every baseline the fringe files are collected/computed, the snr and
    quality code cuts are applied, the phase residuals are gathered channel by
    channel with the channel mean removed, and a PNG with one panel per channel
    (optionally colored by amp or dtec) is written to the current directory.

    Exits with status 1 when the control file is missing, the polarization
    product is not recognized, no stations are given, or a baseline has fringe
    files but no valid phase residuals.
    """
    args = _parse_arguments()

    control_file = args.control_file
    stations = args.stations
    polprod = args.pol_product
    exp_dir = args.experiment_directory

    abs_exp_dir = os.path.abspath(exp_dir)
    abs_control_file = os.path.abspath(control_file)
    exp_name = os.path.split(abs_exp_dir)[1]
    rnsigma = args.remove_outlier_nsigma

    if not os.path.isfile(abs_control_file):
        print("could not find control file: ", control_file)
        sys.exit(1)

    #pol product:
    if polprod not in ['XX', 'YY', 'XY', 'YX', 'I']:
        print("polarization product must be one of: XX, YY, XY, YX, or I")
        sys.exit(1)

    #the first station is indexed below, so an empty string cannot be used
    if not stations:
        print("at least one station code must be specified")
        sys.exit(1)

    #determine all possible baselines
    print('Calculating baselines')
    baseline_list = vpal.processing.construct_valid_baseline_list(abs_exp_dir, stations[0], stations[1:], network_reference_baselines_only=False)
    print('Baselines:', baseline_list)

    #accepted quality codes, from the lower limit up to 9
    qcode_list = [str(q) for q in range(args.quality_lower_limit, 10)]

    #needed for plot-naming
    control_file_stripped = re.sub(r'[/\.]', '', control_file)

    for bline in baseline_list:
        #default output filename
        plot_name = "./phresid_" + bline + "_" + polprod + "_" + control_file_stripped + "_" + exp_name
        if args.z_axis != 'none':
            plot_name += "_" + args.z_axis

        print('Collecting fringe files for baseline',bline)

        #collect/compute fringe files
        set_commands = "set gen_cf_record true"
        ff_list = vpal.processing.load_and_batch_fourfit(
            abs_exp_dir, bline[0], bline[1], abs_control_file, set_commands,
            num_processes=args.num_proc, start_scan_limit=args.begin_scan_limit,
            stop_scan_limit=args.end_scan_limit, pol_products=[polprod], use_progress_ticker=args.use_progress_ticker
        )

        print("n fringe files  =", str(len(ff_list)))

        #apply cuts
        #NOTE(review): the return value is ignored, so combined_filter is assumed to prune ff_list in place
        filter_list = [
            vpal.utility.DiscreteQuantityFilter("quality", qcode_list),
            vpal.utility.ContinuousQuantityFilter("snr", args.snr_min, 500),
        ]
        vpal.utility.combined_filter(ff_list, filter_list)

        if len(ff_list) == 0:
            print("Error: no fringe files available after cuts, skipping baseline: ", bline)
            continue

        #negate, unwrap, remove mean phase, and limit to -180..180 for every scan
        phase_residuals = _collect_phase_residuals(ff_list)
        if len(phase_residuals) == 0:
            print("Error: could not find phase residuals")
            sys.exit(1)

        #figure out the earliest day in the collection
        first_day = min( int( phr.scan_name.split('-')[0] ) for phr in phase_residuals )

        print("Processing phase residuals for baseline: ", bline )
        channel_residuals, z_list = _collect_channel_residuals(phase_residuals, args.channels, first_day, args.z_axis, rnsigma)

        #nothing survived the channel selection/outlier cut, so there is nothing to normalize or plot
        if not z_list:
            print("Error: no phase residuals left to plot after cuts, skipping baseline: ", bline)
            continue

        #plot the channel phase vs. time for each channel
        plottitle = bline + ":" + polprod + " channel phase residuals vs. time for snr > " + str(args.snr_min)
        _plot_channel_residuals(channel_residuals, z_list, args.channels, first_day, args.z_axis, plottitle, plot_name)
        print('Done\n')


if __name__ == '__main__':
    # script entry point: run the tool, then report success to the shell
    main()
    raise SystemExit(0)
