#!/usr/bin/python

#core imports
from __future__ import print_function
from __future__ import division
from builtins import zip
from builtins import range
from past.utils import old_div
import argparse
import sys
import os
import math

#non-core imports
import numpy as np
import scipy.stats
#set the plotting back-end to 'agg' to avoid display
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cmx
import matplotlib.gridspec as gridspec
import matplotlib.mlab as mlab
import matplotlib.ticker as ticker
import matplotlib.patches as mpatches
import matplotlib.mlab as mlab


#HOPS module imports
import vpal.fourphase_lib
import vpal.ffres2pcp_lib
from vpal.utility import tabulate as tabulate

def read_fourphase_report(report_file):
    """Load a fourphase report file.

    Creates an empty vpal.fourphase_lib.ExperimentReportData, asks it to
    load report_file, and returns the populated object.
    """
    report = vpal.fourphase_lib.ExperimentReportData()
    report.load_report(report_file)
    return report

def read_ffres2pcp_report(report_file):
    """Load an ffres2pcp report file.

    Creates an empty vpal.ffres2pcp_lib.ExperimentReportData, asks it to
    load report_file, and returns the populated object.
    """
    report = vpal.ffres2pcp_lib.ExperimentReportData()
    report.load_report(report_file)
    return report


def process_fourphase_text(exp_report):
    """Print a text summary of a fourphase report to stdout.

    Two tables are printed with vpal.utility.print_table:
      1. every attribute of exp_report.config_obj as (name, value) rows;
      2. one row per entry of exp_report.generated_station_offsets holding
         the entry counts and the delay/phase offset statistics.

    Delay quantities are multiplied by 1000 before printing; the column
    headers label them as nanoseconds (the input is presumably in
    microseconds -- confirm against vpal.fourphase_lib).
    """
    print("--------------------------------")
    print("fourphase configuration summary:")
    print("--------------------------------")
    config_table = [["name: ", "value: "]]
    for name, setting in vars(exp_report.config_obj).items():
        config_table.append([name, setting])
    vpal.utility.print_table(config_table, n_digits=3)

    print("-----------------------")
    print("fourphase data summary:")
    print("-----------------------")
    data_table = [[
        "station", "n_total", "n_used", "n_cut",
        "mean_delay (ns)", "delay_error (ns)",
        "median_delay(ns)", "median abs deviation (ns)",
        "mean_phase (deg)", "phase_error (deg)",
    ]]
    ns_per_us = 1000.0
    for _, offsets in exp_report.generated_station_offsets.items():
        #the row layout must follow the header order above
        data_table.append([
            offsets.station_id,
            offsets.get_n_total_entries(),
            offsets.get_n_used_entries(),
            offsets.get_n_cut_entries(),
            offsets.get_delay_offset_mean()*ns_per_us,
            offsets.get_delay_offset_error()*ns_per_us,
            offsets.get_delay_offset_median()*ns_per_us,
            offsets.get_delay_offset_mad()*ns_per_us,
            offsets.get_phase_offset_mean(),
            offsets.get_phase_offset_error(),
        ])
    vpal.utility.print_table(data_table, n_digits=3)


################################################################################

def process_fourphase_plot(exp_report):
    """Write one histogram figure per station of the Y-X delay/phase offsets.

    For every entry of exp_report.generated_station_offsets that has data, a
    two-panel PNG named ./yx_delay_phase_offsets_<station>.png is saved in
    the current directory:

      * top panel: histogram of the delay offsets (scaled by 1000 and
        labelled ns), windowed to median +/- 5 MAD;
      * bottom panel: histogram of the phase offsets (degrees), windowed to
        mean +/- 5 errors and clipped to [-180, 180].

    Each panel is annotated with the station's reported mean (solid black,
    'with cuts'), the mean of all plotted values (dashed red, 'without
    cuts'; a circular mean for the phases) and a Gaussian centred on the
    reported mean whose width is the reported error.

    Stations without delay or phase values only produce a message on stdout.
    """
    micro_to_nano = 1000.0
    n_spread = 5   #half-width of each histogram window, in units of MAD/error
    n_curve_points = 100

    for st, val in exp_report.generated_station_offsets.items():
        delays = [x*micro_to_nano for x in val.get_all_delay_offset_values()]
        phases = val.get_all_phase_offset_values()

        if len(delays) == 0 or len(phases) == 0:
            print("Error: no data for station: " + st + " available in report.")
            continue

        auto_fig = plt.figure(figsize=(7,8))

        #---- delay offsets -------------------------------------------------
        ax = auto_fig.add_subplot(211)
        ax.set_ylabel('Number of scans')
        ax.set_xlabel('Y-X delay offset (ns) for station: ' + st, labelpad=5)
        ax.set_title("Histogram of Y-X delay offsets for station: " + st)

        delay_mean = val.get_delay_offset_mean()*micro_to_nano
        delay_error = val.get_delay_offset_error()*micro_to_nano
        duplim = (val.get_delay_offset_median() + n_spread*val.get_delay_offset_mad())*micro_to_nano
        dlowlim = (val.get_delay_offset_median() - n_spread*val.get_delay_offset_mad())*micro_to_nano

        ndbins = 50
        ax.hist(delays, ndbins, range=[dlowlim, duplim], facecolor='b', alpha=0.75)
        #indicate and annotate the central estimators
        ax.axvline(delay_mean, color='k', linestyle='solid', linewidth=1, label='mean w/ cuts') #mean of cut data
        ax.axvline(np.mean(delays), color='r', linestyle='dashed', linewidth=1, label='mean w/o cuts\n(not all outliers shown)') #mean of uncut data
        #legend is drawn before the gaussian so only the two mean lines are listed
        ax.legend()

        #plot a gaussian about the central estimator; it is scaled by
        #(number of values)*(bin width) so its area is comparable to the histogram
        dval = np.linspace(dlowlim, duplim, n_curve_points)
        dx = (duplim - dlowlim)/float(ndbins)
        scale = len(delays)*dx
        ax.plot(dval, scale*scipy.stats.norm.pdf(dval, delay_mean, delay_error))

        #---- phase offsets -------------------------------------------------
        ax2 = auto_fig.add_subplot(212)
        ax2.set_ylabel('Number of scans')
        ax2.set_xlabel('Y-X phase offset (degrees) for station: ' + st, labelpad=5)
        ax2.set_title("Histogram of Y-X phase offsets for station: " + st)

        phase_mean = val.get_phase_offset_mean()
        phase_error = val.get_phase_offset_error()
        phuplim = min(phase_mean + n_spread*phase_error, 180.0)
        phlowlim = max(phase_mean - n_spread*phase_error, -180.0)

        nphbins = 50
        ax2.hist(phases, nphbins, range=[phlowlim, phuplim], facecolor='g', alpha=0.75)
        ax2.axvline(phase_mean, color='k', linestyle='solid', linewidth=1) #mean of cut data
        ax2.axvline(scipy.stats.circmean( np.asarray(phases), low=-180.0, high=180.0), color='r', linestyle='dashed', linewidth=1)  #mean of no-cut data

        #plot a gaussian about the central estimator
        phval = np.linspace(phlowlim, phuplim, n_curve_points)
        dph = (phuplim - phlowlim)/float(nphbins)
        scale = len(phases)*dph
        ax2.plot(phval, scale*scipy.stats.norm.pdf(phval, phase_mean, phase_error))

        auto_fig.tight_layout()
        auto_fig_name = "./yx_delay_phase_offsets_" + st + ".png"
        auto_fig.savefig(auto_fig_name)
        plt.close(auto_fig)


################################################################################

def process_ffres2pcp_text(exp_report):
    """Print a text summary of an ffres2pcp report and return its tables.

    First the attributes of exp_report.config_obj are printed as a
    (name, value) table.  Then, for every 'station_pol' key found in
    exp_report.apriori_sspcp and/or exp_report.generated_sspcp, a table with
    one column per channel of DefaultChannelDefines().vgos is printed.

    Returns a dict mapping each 'station_pol' key to its table: a list of
    four rows (a-priori, delta, result, error), each starting with a string
    label followed by one value per channel, rounded to one decimal.
    Channels without a value are reported as 0.0.
    """
    print("--------------------------------")
    print("ffres2pcp configuration summary:")
    print("--------------------------------")
    config_table = [["name: ", "value: "]]
    for name, setting in vars(exp_report.config_obj).items():
        config_table.append([name, setting])
    vpal.utility.print_table(config_table, n_digits=3)

    print("-----------------------")
    print("ffres2pcp data summary:")
    print("-----------------------")

    #get the station/pol data
    apriori = exp_report.apriori_sspcp
    generated = exp_report.generated_sspcp
    station_pols = set(apriori.keys()).union( set(generated.keys()) )

    #per key: (a-priori, delta, final, error) channel->value collections
    pc_info = dict()
    for key in station_pols:
        has_apriori = key in apriori
        has_generated = key in generated
        if has_apriori and has_generated:
            ap_values = apriori[key].get_mean_phases()
            gen_values = generated[key].get_mean_phases()
            err_values = generated[key].get_stddev_phases()
            final_values = vpal.ffres2pcp_lib.sum_mean_pc_phases(apriori[key], generated[key])
        elif has_apriori:
            #nothing was generated: the a-priori values are the result
            ap_values = apriori[key].get_mean_phases()
            final_values = apriori[key].get_mean_phases()
            gen_values = dict()
            err_values = dict()
        else:
            #key came from the union, so it must be in the generated set
            ap_values = dict()
            gen_values = generated[key].get_mean_phases()
            err_values = generated[key].get_stddev_phases()
            final_values = generated[key].get_mean_phases()
        pc_info[key] = (ap_values, gen_values, final_values, err_values)

    #for each station,pol we form a table of "a-priori pc_phases, delta pc_phases, new pc_phases and pc_phases errors"
    row_labels = ("a-priori pc_phase: ", "delta pc_phase: ", "result pc_phase: ", "error: ")
    result_table = dict()
    for key, per_kind_values in pc_info.items():
        st, pol = key.split('_')
        print("-----------------------")
        print("station: ", st, ", pol: ", pol)
        print("-----------------------")

        ch_defs = vpal.ffres2pcp_lib.DefaultChannelDefines()

        header_line = ["channel: "]
        table = [[label] for label in row_labels]
        for ch in ch_defs.vgos:
            header_line.append(ch)
            #rows and value collections share the (a-priori, delta, final, error) order
            for row, values in zip(table, per_kind_values):
                if ch in values:
                    row.append( round(values[ch], 1) )
                else:
                    row.append(0.0)

        print(tabulate(table, headers=header_line))
        result_table[key] = table

    return result_table


################################################################################

def _round_away_from_zero(value, step=10):
    """Move value away from zero onto the next multiple of step (multiples are kept)."""
    return value + math.copysign((-abs(value)) % step, value)


def _hide_spines(axes):
    """Make all four spines of the given axes invisible."""
    for side in ('top', 'bottom', 'left', 'right'):
        axes.spines[side].set_color('none')


def process_ffres2pcp_plot(report_data, result_table):
    """Write the ffres2pcp summary plots as PNG files in the current directory.

    Two kinds of figure are produced:

      * ./<station>_<pol>.png for every key of result_table: the a-priori and
        resulting pc_phases (with error bars) versus channel frequency, drawn
        as four side-by-side panels of eight channels each;
      * ./scan_selection_<station>_<pol>.png for every station/pol present in
        report_data.generated_sspcp: all baseline-scans passing the dtec
        tolerance in the (min_snr, dtec_mdev) plane, with the ones used for
        the pc_phases calculation highlighted.

    Args:
        report_data: loaded ffres2pcp report (channel_freqs, config_obj,
            generated_sspcp and all_blc are read).
        result_table: dict returned by process_ffres2pcp_text; each value is
            a list of rows [a-priori, delta, result, error] whose first
            element is a string label.
    """
    #we want to create a plot of the pc_phases offset for each station/pol
    #first get a list of all the stations:
    all_stations = set([ x.split('_')[0] for x in result_table ])
    all_pols = set([ x.split('_')[1] for x in result_table ])

    channel_freqs = sorted( list(report_data.channel_freqs.items()), key=vpal.ffres2pcp_lib.channel_sort_key)

    ch_defs = vpal.ffres2pcp_lib.DefaultChannelDefines()
    if len(channel_freqs) != len(ch_defs.vgos):
        print("Warning: report file does not contain the requisite number (32) of channel definitions")
        #NOTE(review): execution continues, but the band panels below index a
        #fixed 4x8 channel layout and will raise IndexError (or a length
        #mismatch in plot) when fewer than 32 channels are present.

    Hz_to_GHz = 1e-9
    chan_codes = [ x[0] for x in channel_freqs]
    chan_freqs = [ x[1]*Hz_to_GHz for x in channel_freqs]

    freq_buffer = 0.01
    chans_per_band = 8
    for st in all_stations:
        for p in all_pols:
            key_to_plot = st + '_' + p
            #only plot the station-pol combos we actually have
            if key_to_plot not in result_table:
                continue

            pcp_data = result_table[key_to_plot]
            apc = pcp_data[0][1:] #strip first element, as it is a string label
            fpc = pcp_data[2][1:]
            epc = pcp_data[3][1:]

            #the y-range always covers at least [-10, 10], so neither
            #extreme can be zero when it is rounded below
            limits_pc = [10.0, -10.0]
            limits_pc.extend(fpc)
            limits_pc.extend(apc)
            limits_pc.extend( [x + y for x, y in zip(fpc, epc)] )
            limits_pc.extend( [x - y for x, y in zip(fpc, epc)] )

            #round max/min to nearest 10, in direction away from zero
            min_pc = _round_away_from_zero(min(limits_pc))
            max_pc = _round_away_from_zero(max(limits_pc))

            #coarser tick spacing for larger phase ranges
            span_pc = abs(max_pc - min_pc)
            if span_pc >= 200:
                deg_step = 30.0
            elif span_pc >= 100:
                deg_step = 10.0
            else:
                deg_step = 5.0
            n_div = int( math.floor(span_pc/deg_step) )
            deg_ticks = [min_pc + i*deg_step for i in range(0, n_div)]

            auto_fig = plt.figure(figsize=(14,8))

            #invisible full-size axes that only carry the shared labels/title
            bigax = auto_fig.add_subplot(111)
            _hide_spines(bigax)
            #tick_params expects booleans; the legacy 'on'/'off' strings are
            #not interpreted by current matplotlib and would all count as True
            bigax.tick_params(top=False, bottom=False, left=True, right=False)
            bigax.set_ylabel('Phase Correction (deg)')
            bigax.set_xlabel('Channel Frequency (GHz)', labelpad=20)
            bigax.set_title(p + '-pol phase corrections for station (' + st + ') by channel', y=1.11 )
            bigax.set_yticks(deg_ticks)
            bigax.set_ylim(min_pc, max_pc)
            bigax.set_xticks([])
            bigax.xaxis.set_ticks_position('none')
            bigax2 = bigax.twiny()
            bigax2.set_xlim(bigax.get_xlim())
            bigax2.set_xlabel('Channel Index', labelpad=20)
            bigax2.set_xticklabels([])
            _hide_spines(bigax2)

            for band in [1,2,3,4]:
                first = (band-1)*chans_per_band
                last = band*chans_per_band
                band_freqs = chan_freqs[first:last]
                band_xlim = (chan_freqs[first]-freq_buffer, chan_freqs[last-1]+freq_buffer)

                ax = auto_fig.add_subplot(1,4,band)
                aprioripc = ax.plot(band_freqs, apc[first:last], '-ob')
                genpc = ax.errorbar(x=band_freqs, y=fpc[first:last], yerr=epc[first:last], fmt='-rs', capsize=3)
                #fix the limits after plotting so autoscaling cannot override them
                ax.set_ylim(min_pc, max_pc)
                ax.set_xlim(band_xlim[0], band_xlim[1])
                ax.set_yticks(deg_ticks)
                ax.grid(True)
                ax.tick_params(top=False, bottom=True, left=True, right=False)
                plt.setp(ax.get_yticklabels(), visible=False)
                if band == 1:
                    ax.legend([aprioripc[0], genpc[0]], ["a priori pc_phases", "estimated pc_phases"])

                #add channel code labels
                ax2 = ax.twiny()
                ax2.set_yticklabels([])
                ax2.set_xlim(band_xlim[0], band_xlim[1])
                ax2.set_xticks(band_freqs)
                ax2.set_xticklabels(chan_codes[first:last])
                plt.setp(ax2.get_yticklabels(), visible=False)

            auto_fig_name = "./" + key_to_plot + ".png"
            auto_fig.savefig(auto_fig_name)
            plt.close(auto_fig)

    #now for each station we want to go through the baseline-scan collections
    #and plot the scans in the (min_snr, dtec_mdev) plane so we can see
    #which ones were chosen for the pc_phases calculation
    #NOTE(review): the '+' assumes target_stations and network_reference_station
    #are both sequences of station codes of the same type -- confirm against the config class
    all_stations = report_data.config_obj.target_stations + report_data.config_obj.network_reference_station
    for st in all_stations:
        for p in all_pols:
            #first find all the scan/baselines used for this station/pol
            key_to_plot = st + '_' + p
            if key_to_plot not in report_data.generated_sspcp:
                continue

            scan_bl_used = report_data.generated_sspcp[key_to_plot].scan_baselines_used
            all_min_snr_list = []
            all_dtec_mdev_list = []
            selected_min_snr_list = []
            selected_dtec_mdev_list = []
            for blc in report_data.all_blc:
                if st in blc.baseline:
                    scan_bl = blc.scan_name + "_" + blc.baseline
                    if blc.dtec_mdev < report_data.config_obj.dtec_tolerance:
                        all_min_snr_list.append(float(blc.min_snr))
                        all_dtec_mdev_list.append(float(blc.dtec_mdev))
                        if scan_bl in scan_bl_used:
                            selected_min_snr_list.append(float(blc.min_snr))
                            selected_dtec_mdev_list.append(float(blc.dtec_mdev))

            auto_fig = plt.figure(figsize=(6,6))
            ax = auto_fig.add_subplot(111)
            ax.set_ylabel('Maximum dTEC deviation from mean (TECU)')
            ax.set_xlabel('Minimum SNR over all pol-products', labelpad=20)
            ax.set_title("Scan selection for station: " + st + ", pol: " + p)
            all_scans = ax.plot(all_min_snr_list, all_dtec_mdev_list, 'k.')
            selected_scans = ax.plot(selected_min_snr_list, selected_dtec_mdev_list, 'b*')
            ax.legend([all_scans[0], selected_scans[0]], ["scans considered", "selected scans"])

            auto_fig_name = "./scan_selection_" + key_to_plot + ".png"
            auto_fig.savefig(auto_fig_name)
            plt.close(auto_fig)



################################################################################

def main():
    """Command-line entry point: summarize one ffres2pcp/fourphase report file.

    The report type is inferred from the report path: the file name is
    inspected first, and only when it contains neither 'ffres2pcp' nor
    'fourphase' is the full absolute path searched.  This keeps a directory
    name from overriding what the file name itself says.

    A text summary is always printed; plots are written as well unless
    -t/--text-only is given.

    Returns:
        0 on success, 1 when the report type cannot be determined.
    """
    parser = argparse.ArgumentParser(
        prog='summarize_report.py',
        description='''utility for generating text and plot summaries of ffres2pcp.py/fourphase.py report files'''
        )

    parser.add_argument('report_file', help='report file name')
    parser.add_argument('-t', '--text-only', action='store_true', dest='text_mode', help='do not plot, issue text report only.', default=False)

    args = parser.parse_args()
    report_file = os.path.abspath(args.report_file)

    #file name takes precedence; the full path is only a fallback
    mode = None
    for candidate in (os.path.basename(report_file), report_file):
        if 'ffres2pcp' in candidate:
            mode = 'ffres2pcp'
        elif 'fourphase' in candidate:
            mode = 'fourphase'
        if mode is not None:
            break

    if mode is None:
        print('report format unrecognized: file name must indicate whether it is an ffres2pcp or fourphase report.')
        return 1

    if mode == 'ffres2pcp':
        report_data = read_ffres2pcp_report(report_file)
        result_table = process_ffres2pcp_text(report_data)
        if not args.text_mode:
            process_ffres2pcp_plot(report_data, result_table)
    else:
        report_data = read_fourphase_report(report_file)
        process_fourphase_text(report_data)
        if not args.text_mode:
            process_fourphase_plot(report_data)

    return 0


if __name__ == '__main__':          # official entry point
    #hand main()'s status to the shell so a failed run (return value 1)
    #is not reported as success
    sys.exit(main())
