#!/usr/bin/python

#core imports
from __future__ import print_function
from __future__ import division
from builtins import str
from builtins import range
from past.utils import old_div
import argparse
import sys
import os
import math
import re

#non-core imports
#set the plotting back-end to 'agg' to avoid display
import matplotlib as mpl
mpl.use('Agg')
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.colorbar as mcbar
import matplotlib.cm as cmx
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

#HOPS module imports
import vpal.processing
import vpal.utility
import vpal.fringe_file_manipulation

import hopstestb as ht

################################################################################

def _build_argument_parser():
    """Return the argparse parser describing the command-line interface.

    The positional arguments and all pre-existing options are unchanged.
    The optional ``-P/--pol-product`` flag supplies the polarization product
    that the fourfit batch call and the plot title require.
    """
    parser = argparse.ArgumentParser(
        prog='parallactic_plots.py',
        description='''utility for plotting parallactic angle dependence of pol-product amplitudes'''
    )

    parser.add_argument('control_file', help='the control file to be applied to all scans')
    parser.add_argument('stations', help='concatenated string of single character codes for all stations to be fringe fit')
    parser.add_argument('experiment_directory', help='relative path to directory containing experiment data')

    parser.add_argument('-n', '--numproc', type=int, dest='num_proc', help='number of concurrent fourfit jobs to run, default=1', default=1)
    parser.add_argument('-c', '--channels', dest='channels', help='specify the channels to be used, default=abcdefghijklmnopqrstuvwxyzABCDEF.', default='abcdefghijklmnopqrstuvwxyzABCDEF')
    parser.add_argument('-s', '--snr-min', type=float, dest='snr_min', help='set minimum allowed snr threshold, default=30.', default=30)
    parser.add_argument('-d', '--dtec-threshold', type=float, dest='dtec_thresh', help='set maximum allowed difference in dTEC, default=1.', default=1.0)
    parser.add_argument('-q', '--quality-limit', type=int, dest='quality_lower_limit', help='set the lower limit on fringe quality (inclusive), default=3.', default=3)
    parser.add_argument('-p', '--progress', action='store_true', dest='use_progress_ticker', help='monitor process with progress indicator', default=False)
    parser.add_argument('-b', '--begin-scan', dest='begin_scan_limit', help='limit the earliest scan to be used e.g 244-1719', default="000-0000")
    parser.add_argument('-e', '--end-scan', dest='end_scan_limit', help='limit the latest scan to be used, e.g. 244-2345', default="999-9999")
    # type=float makes argparse reject a non-numeric value with a usage error
    # instead of a traceback later on
    parser.add_argument('-r', '--remove-outliers', type=float, dest='remove_outlier_nsigma', help='remove scans which are n*sigma away from the mean, default=0 (off)', default=0.0)
    parser.add_argument('-z', '--z-axis', dest='z_axis', help='select color axis, options are: amp, dtec, or none, default=none', default='none')
    # optional (rather than positional) so that existing invocations keep working
    parser.add_argument('-P', '--pol-product', dest='pol_product', choices=['XX', 'YY', 'XY', 'YX', 'I'], help='the polarization-product to be fringe fit, default=I.', default='I')
    return parser


def _collect_phase_residuals(ff_list):
    """Extract the per-channel phase residuals of every fringe file in ff_list.

    For each fringe file whose PhaseResidualData is flagged valid, the channel
    phases (degrees) are negated, unwrapped, reduced by their circular mean and
    limited to the range -180..180 before being stored back on the object.

    Returns the list of valid PhaseResidualData objects.
    """
    collected = []
    for ff in ff_list:
        phresid = vpal.fringe_file_manipulation.PhaseResidualData()
        phresid.extract(ff.filename)
        if phresid.is_valid is not True:
            continue
        channels = list(phresid.phase_residuals.keys())
        if channels:
            raw_deg = np.asarray([phresid.phase_residuals[ch] for ch in channels])
            # np.unwrap works in radians: negate, convert, unwrap, convert back
            unwrapped_deg = np.degrees(np.unwrap(-np.radians(raw_deg)))
            mean_phase = scipy.stats.circmean(unwrapped_deg, high=180.0, low=-180.0)
            for ch, phase in zip(channels, unwrapped_deg - mean_phase):
                phresid.phase_residuals[ch] = vpal.utility.limit_periodic_quantity_to_range(
                    phase, low_value=-180.0, high_value=180.0)
        collected.append(phresid)
    return collected


def _scan_time_hours(scan_name, first_day):
    """Convert a 'DDD-HHMM...' scan name to hours elapsed since first_day 00:00."""
    fields = scan_name.split('-')
    day = int(fields[0])
    hour_min = fields[1]
    return (day - first_day)*24.0 + float(hour_min[0:2]) + float(hour_min[2:4])/60.0


def _build_channel_residuals(phase_residuals, channels, first_day, z_axis, rnsigma):
    """Build the per-channel plot data from the collected phase residuals.

    Returns a tuple (channel_residuals, z_list):
      channel_residuals -- dict mapping each channel to a list of
                           (time_in_hours, phase - channel_mean_phase, z) tuples
      z_list            -- all color-axis values that were kept, over all channels

    When rnsigma is non-zero, points lying rnsigma circular standard deviations
    or more away from the channel's circular mean phase are dropped.
    """
    channel_residuals = dict()
    z_list = []
    for ch in channels:
        channel_residuals[ch] = []
        with_channel = [phr for phr in phase_residuals if ch in phr.phase_residuals]
        for phr in with_channel:
            if abs(phr.phase_residuals[ch]) > 180:
                print("Scan: ", phr.scan_name, "has channel phase of: ", phr.phase_residuals[ch])
        if not with_channel:
            continue

        phases = np.asarray([phr.phase_residuals[ch] for phr in with_channel])
        mean_phase = scipy.stats.circmean(phases, high=180.0, low=-180.0)  # circular mean phase
        stddev = scipy.stats.circstd(phases, high=180, low=-180)  # circular std. dev.
        print("(mean, std. dev) phase for channel: ", ch, " = ", mean_phase, stddev)

        for phr in with_channel:
            resid = phr.phase_residuals[ch] - mean_phase
            if rnsigma != 0.0 and not abs(resid) < rnsigma*stddev:
                continue  # outlier, strip it
            # color axis data
            z = 0.0
            if z_axis == 'dtec':
                z = phr.dtec
            if z_axis == 'amp':
                z = phr.channel_phasor_amplitudes[ch]
            channel_residuals[ch].append((_scan_time_hours(phr.scan_name, first_day), resid, z))
            z_list.append(z)
    return channel_residuals, z_list


def _phase_tick_steps(ph_range):
    """Return (major, minor) y-axis tick spacing in degrees for a given phase range."""
    if ph_range > 135:
        return 90, 30
    if ph_range > 90:
        return 45, 15
    if ph_range > 45:
        return 30, 10
    if ph_range >= 15:
        return 15, 5
    return 5, 1


def _plot_channel_residuals(channel_residuals, z_list, channels, first_day, z_axis, plot_title, plot_name):
    """Draw one phase-vs-time panel per populated channel and save it as <plot_name>.png.

    z_list must be non-empty; it fixes the color normalization. A horizontal
    color bar is added only when z_axis is 'dtec' or 'amp'.
    """
    use_color_axis = z_axis == 'dtec' or z_axis == 'amp'

    auto_fig = plt.figure(figsize=(8.5, 11))
    ax = auto_fig.add_subplot(111)    # the big subplot, only carries the common labels

    # turn off axis lines and ticks of the big subplot; the tick switches must be
    # booleans, the string 'off' is truthy and would turn the ticks on
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_color('none')
    ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)

    # set common labels
    ax.set_title(plot_title)
    ax.set_ylabel('Phase (degrees)')

    # plt.get_cmap is used because matplotlib.cm.get_cmap no longer exists in
    # recent matplotlib releases
    maps = plt.colormaps()
    cmap = plt.get_cmap('viridis' if 'viridis' in maps else maps[0])
    normalize = mcolors.Normalize(vmin=min(z_list), vmax=max(z_list))

    n_rows = int(math.ceil(len(channels)/2.0) + 1)
    x_label = 'Time (hours) since ' + str(first_day).zfill(3) + "-00:00"

    count = 0
    for ch in channels:
        points = channel_residuals[ch]
        if len(points) == 0:
            continue
        count += 1
        tmpax = auto_fig.add_subplot(n_rows, 2, count)
        tmpax.text(-0.1, 0.5, ch, fontsize=10, ha='center', transform=tmpax.transAxes)
        x_val = [p[0] for p in points]
        y_val = [p[1] for p in points]
        z_val = [p[2] for p in points]

        # choose the tick spacing from the largest absolute phase in this channel
        ph_range = max(abs(max(y_val)), abs(min(y_val)))
        major_step, minor_step = _phase_tick_steps(ph_range)
        tmpax.yaxis.set_major_locator(MultipleLocator(major_step))
        tmpax.yaxis.set_major_formatter(FormatStrFormatter('%d'))
        # for the minor ticks, use no labels; default NullFormatter
        tmpax.yaxis.set_minor_locator(MultipleLocator(minor_step))
        tmpax.set_xlabel(x_label)
        tmpax.tick_params(axis='both', which='both', direction='in', labelsize=6)
        if use_color_axis:
            tmpax.scatter(x_val, y_val, s=1, c=[cmap(normalize(value)) for value in z_val])
        else:
            tmpax.scatter(x_val, y_val, s=1, c='r')

    auto_fig.subplots_adjust(hspace=0.1)

    if use_color_axis:
        cax, _ = mcbar.make_axes(ax, location='bottom', anchor=(0.5, -0.2), aspect=30)
        cbar = mcbar.ColorbarBase(cax, cmap=cmap, norm=normalize, orientation='horizontal')
        cbar.ax.set_xlabel(z_axis)

    auto_fig_name = plot_name + ".png"
    print("generating plot: ", auto_fig_name)
    auto_fig.savefig(auto_fig_name, bbox_inches='tight')
    # no plt.show(): this file forces the non-interactive 'Agg' backend
    plt.close(auto_fig)


def main():
    """Command-line entry point of parallactic_plots.py.

    Usage: parallactic_plots.py [options] <control-file> <stations> <experiment-directory>
    e.g.:  parallactic_plots.py ./cf_GHEVY_ff GHEV ./

    For every valid baseline among the given stations the scans are fringe fit
    with the given control file, cut on quality code and snr, and a plot
    ./parallactic_<baseline>_<control-file>_<experiment>[_<z-axis>].png is written.

    NOTE: the parallactic-angle / pol-product amplitude extraction this tool is
    named after is not implemented yet (see the TODO in the baseline loop).
    What is plotted is the per-channel phase residual versus time.

    Exits with status 1 if the control file does not exist or if a baseline
    yields no valid phase residuals; argparse exits with status 2 on bad usage.
    """
    parser = _build_argument_parser()
    args = parser.parse_args()

    control_file = args.control_file
    stations = args.stations
    polprod = args.pol_product
    exp_dir = args.experiment_directory
    rnsigma = args.remove_outlier_nsigma

    # stations[0] is the reference station, so an empty string cannot be used
    if not stations:
        parser.error("at least one station code is required")

    # fail fast, before any fringe fitting is attempted
    abs_control_file = os.path.abspath(control_file)
    if not os.path.isfile(abs_control_file):
        print("could not find control file: ", control_file)
        sys.exit(1)

    abs_exp_dir = os.path.abspath(exp_dir)
    exp_name = os.path.split(abs_exp_dir)[1]

    #determine all possible baselines
    baseline_list = vpal.processing.construct_valid_baseline_list(abs_exp_dir, stations[0], stations[1:], network_reference_baselines_only=False)

    qcode_list = [str(q) for q in range(args.quality_lower_limit, 10)]

    #needed for plot-naming
    control_file_stripped = re.sub(r'[/.]', '', control_file)

    # bline is a two-element string of [reference_station, other_station]
    for bline in baseline_list:
        #default output filename
        plot_name = "./parallactic_" + bline + "_" + control_file_stripped + "_" + exp_name
        if args.z_axis != 'none':
            plot_name += "_" + args.z_axis

        #collect/compute fringe files, and apply cuts
        set_commands = "set gen_cf_record true"
        ff_list = vpal.processing.load_and_batch_fourfit(
            abs_exp_dir, bline[0], bline[1], abs_control_file, set_commands,
            num_processes=args.num_proc, start_scan_limit=args.begin_scan_limit,
            stop_scan_limit=args.end_scan_limit, pol_products=[polprod], use_progress_ticker=args.use_progress_ticker
        )
        print("n fringe files  =", str(len(ff_list)))

        #apply cuts
        filter_list = [
            vpal.utility.DiscreteQuantityFilter("quality", qcode_list),
            vpal.utility.ContinuousQuantityFilter("snr", args.snr_min, 500),
        ]
        vpal.utility.combined_filter(ff_list, filter_list)

        if len(ff_list) == 0:
            print("Error: no fringe files available after cuts, skipping baseline: ", bline)
            continue

        # TODO: parallactic-angle dependence is not implemented yet. Planned steps:
        #   (1) find the root file of each scan and work out the station data
        #       file of each station (the mk4 site id is the bline element)
        #   (2) get the parallactic angles and calculate dPar
        #   (3) read the fringe file for each pol-product and get the scan amplitude
        #   (4) build, per pol-product, arrays of parallactic angle and
        #       (normalized) amplitude and plot them for each baseline

        phase_residuals = _collect_phase_residuals(ff_list)
        if len(phase_residuals) == 0:
            print("Error: could not find phase residuals")
            sys.exit(1)

        #figure out the earliest day in the collection
        first_day = min(int(phr.scan_name.split('-')[0]) for phr in phase_residuals)

        print("Processing phase residuals for baseline: ", bline)
        channel_residuals, z_list = _build_channel_residuals(
            phase_residuals, args.channels, first_day, args.z_axis, rnsigma)

        # without any data point the color normalization (min/max) cannot be built
        if len(z_list) == 0:
            print("Error: no channel data to plot, skipping baseline: ", bline)
            continue

        plottitle = bline + ":" + polprod + " channel phase residuals vs. time for snr > " + str(args.snr_min)
        _plot_channel_residuals(channel_residuals, z_list, args.channels, first_day, args.z_axis, plottitle, plot_name)


if __name__ == '__main__':
    # official entry point: run the tool, then terminate with an explicit
    # success status (sys.exit(0) is exactly "raise SystemExit(0)")
    main()
    raise SystemExit(0)
