#!/usr/bin/python

'''
select_bandpols.py
'''

# Usage text printed for -h or when the script is run without arguments.
# The file names quoted below follow the naming used further down in this
# script: the band-pol log/plot and the diagnostics file carry the suffix
# _m<threshold> (e.g. _m0.50), the cross-correlation plot does not.
help_text = '''
This script selects good band-pols whose rows in the correlation matrix have
medians at or above the threshold (0.5 by default). The rows and columns of
bad band-pols are iteratively removed from the matrix, improving the median
values for other band-pols, until all medians are at or above the threshold
or only two band-pols are left. The correlation matrix, medians, and multiple
correlation coefficients on each iteration (if any) are saved in text files.

Unlike plot_multcorr.py, the script can work on all the data (option -a) for
one or several stations (like -s E, -s EGY, or -s IEHV).

Generates plots of the band-pols for one experiment on one station, or all the
experiments on one or more stations. Computes medians for the rows (or columns)
of the correlation matrix. Computes the multiple correlation coefficients of
each band-pol with respect to other band-pols.

The plots with the correlation median below the threshold are marked as
"Rejected".

Arguments:
  -m <threshold>              for correlation median, -1 to 1 (default 0.5).
  -s <station (single) letter code>, like E, G, H ... (or in lower case, e, g, h ...);
  -d <pcc_datfiles directory>       like /data/geodesy/3686/pcc_datfiles
  -o <output directory name>        where .png graphs and .txt logs are saved
  -b    show band-pol plot in X-window
  -c    show x-correlation matrix
  -a    process all available data under directory in -d (like -d /data/geodesy)
        If more stations are given in -s, like -s EY or -s VIGH, only the
        data for those stations are plotted and saved in -o directory.
        If -a is present, -b and -c are ignored (for too many windows would
        be opened). WARNING, this function not completely supported.
  -n    do not create plots and do not save them in .png files
  -h    print this text.

Examples:

(1) Select good band-pols in experiment 3658 from station E. Save the band-pol
plot and the correlation matrix plot in directory pltE. Show both plots
on the screen (-b and -c).
The process of successive selection of good band-pols is logged in text file
bandpol_E_3658_VT8204_m0.50.txt.
The plots are saved as bandpol_E_3658_VT8204_m0.50.png and
xcorrmx_E_3658_VT8204.png:

%run select_bandpols.py -s E -d /data/geodesy/3658/pcc_datfiles_jb/ -o pltE \
                          -b -c

(2) Select good band-pols in all (-a) the experimental data from station V
located under directory /data/geodesy/.
Save the band-pol plots and the correlation matrix plots in directory pltV.
The keys -b and -c (show plots on the screen) are ignored when -a is used.
The processes of successive selection of good band-pols are
logged in text file bandpol_V_<exp.code>_<exp.name>_m0.50.txt.
The plots are saved as bandpol_V_<exp.code>_<exp.name>_m0.50.png and
xcorrmx_V_<exp.code>_<exp.name>.png.
The diagnostic messages are logged in diagnostics_st_V_m0.50.txt:

%run select_bandpols.py -s V -d /data/geodesy/ -o pltV -b -c -a

(3) Select good band-pols in all (-a) the experimental data from stations
I and Y. Save the band-pol plots and the correlation matrix plots in
directory pltIY.
The processes of successive selection of good band-pols are
logged in text file bandpol_<station>_<exp.code>_<exp.name>_m0.50.txt.
The plots are saved as bandpol_<station>_<exp.code>_<exp.name>_m0.50.png and
xcorrmx_<station>_<exp.code>_<exp.name>.png.
The diagnostic messages are logged in diagnostics_st_IY_m0.50.txt:

%run select_bandpols.py -s IY -d /data/geodesy/ -o pltIY -c -a
'''


import numpy as np
import numpy.linalg as la
from scipy.signal import medfilt
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import vpal.phasecal
from vpal.phasecal import mult_corr, write_xcorrmx, write_title, write_numbers
import os, sys, glob, copy
import re
import datetime, time
from functools import reduce
import getopt


# threshold_mulcor = 0.9

# Default threshold for the correlation-matrix row medians; option -m
# overrides it.
threshold_median = 0.5

#all_stations = ('E', 'G', 'H', 'I', 'V', 'Y', 'M')

# Band-pol symbols: every band letter combined with every polarization
# letter, in the order AX, AY, BX, BY, CX, CY, DX, DY.
bp_sym = [band + pol for band in 'ABCD' for pol in 'XY']
nbp_sym = len(bp_sym)

# Regex pattern to catch bandpols, like AX
bp_patt = re.compile(r'[A-D][X-Y]')

# Column layout of one row of a bandmodel .dat file, as a NumPy structured
# dtype description (field name, type code).
fields = [
    ('year', 'i2'),
    ('doy', 'i2'),
    ('hour', 'i2'),
    ('minute', 'i2'),
    ('second', 'i2'),
    ('phase_midband', 'f8'),
    ('phase_dc', 'f8'),
    ('delay_ps', 'f8'),
    ('phase_rmse', 'f8'),
    ('scan', 'S16'),
    ('source', 'S16'),
    ('station', 'S2'),
    ('azimuth', 'f8'),
    ('elevation', 'f8'),
]


#
# Process command line options
#
# Results are left in module-level variables read by the rest of the script:
#   threshold_median, station, dirname, outdir,
#   show_graph, show_xcorrmx, process_all, no_plots.
#
if not sys.argv[1:]:   # Print help text and exit if no command line options
    print(help_text)
    raise SystemExit

try:
    optlist = getopt.getopt(sys.argv[1:], 'm:s:d:o:abchn')[0]
except getopt.GetoptError as err:
    # Unknown option or missing option argument: report it instead of
    # terminating with a traceback.
    print('ERROR: ' + str(err))
    print(help_text)
    raise SystemExit(2)

for opt, val in optlist:
    if opt == '-h':  # Print help text and exit if there is '-h' among options
        print(help_text)
        raise SystemExit

n_datadir = 1
n_station = 1
dirname = ''
station = ''
outdir = ''
nbandpol = 0
show_graph = False
dirname_exists = True
station_data_exist = True
show_xcorrmx = False
process_all = False
no_plots = False

for opt, val in optlist:
    if opt == '-m':
        try:
            m_0 = float(val)
        except ValueError:
            print('ERROR: Option -m requires a number from -1 to 1, got ' +
                  repr(val))
            raise SystemExit(2)
        # Clip the threshold to the valid range of a correlation, [-1, 1]
        if m_0 < -1.:
            threshold_median = -1.
        elif m_0 > 1.:
            threshold_median = 1.
        else:
            threshold_median = m_0
    elif opt == '-s':
        station = val.upper()     # Letter(s) like E or e, G or g etc.
    elif opt == '-d':
        dirname = val
        # An empty value stays empty so that it is reported as a missing
        # data directory instead of raising IndexError here.
        if dirname and not dirname.endswith('/'):
            dirname += '/'
    elif opt == '-o':
        outdir = val              # Like /data/geodesy/3671/pcc_datfiles_jb
        if outdir and not outdir.endswith('/'):
            outdir += '/'
        # makedirs (not mkdir) so that a nested output path can be created
        if outdir and not os.path.isdir(outdir):
            os.makedirs(outdir)
    elif opt == '-c':
        show_xcorrmx = True
    elif opt == '-b':
        show_graph = True
    elif opt == '-a':
        process_all = True
    elif opt == '-n':
        no_plots = True


print('Median threshold: ' + str(threshold_median))
str_mf = '_m' + '%4.2f' % threshold_median


#
# Validate the command line BEFORE any output file is created, so that an
# invalid invocation does not leave empty diagnostics/selections files behind.
#
if station == '':
    print('ERROR: Station must be specified in option -s.')
    station_data_exist = False
else:
    #
    # Leave only unique station letters in the station string,
    # keeping the order in which they were given
    #
    station_temp = ''
    for st in station:
        #if (st in all_stations) and (st not in station_temp):
        if st not in station_temp:
            station_temp += st
    station = station_temp
    n_station = len(station)

if dirname == '':
    print('ERROR: Data directory must be specified in option -d.')
    station_data_exist = False
elif not os.path.isdir(dirname):
    print('ERROR: Path ' + dirname + ' does not exist.')
    dirname_exists = False

#
# Exit (with a non-zero status) if wrong specifications have been noticed
# in command line
#
if not (dirname_exists and station_data_exist):
    raise SystemExit(1)


#
# Write diagnostics file with warnings and error messages
#
fname_warn = outdir + 'diagnostics' + '_st_' + station + str_mf + '.txt'
fwarn = open(fname_warn, 'w')
print('Diagnostics saved in ' + fname_warn)


#
# Write file with selected bands
#
fname_sel = outdir + 'selections' + '_st_' + station + str_mf + '.txt'
fsel = open(fname_sel, 'w')




#
# Build datadir, the list of directories to process, and n_datadir, its length.
#
if process_all:
    show_graph = False       # If -a is present, -b is ignored
    show_xcorrmx = False     # If -a is present, -c is ignored

    #
    # If -a, get list of all the directories with data starting from 4-digit
    # experiment code. Of the pcc_datfiles* directories only those whose
    # name ends in 'datfiles' are kept. The test is made on the normalized
    # path: os.path.abspath() drops the trailing '/', so looking for
    # 'datfiles/' in its result can never succeed.
    #
    datadir = [dd for dd in glob.glob(dirname + '????/pcc_datfiles*/')
               if os.path.abspath(dd).endswith('datfiles')]

    #alternative check if we have been given a single dat directory by itself
    #since this applies if we want to do all stations at once
    if 'datfiles' in os.path.abspath(dirname):
        datadir.append(os.path.abspath(dirname))

    # Keep only the directories with an experiment code like /3693/ in the path
    code_in_path = re.compile(r'/[0-9]{4}/')
    datadir = [dd for dd in datadir
               if code_in_path.search(os.path.abspath(dd))]

    n_datadir = len(datadir)

    if n_datadir == 0:
        print('ERROR: No data directories in ' + dirname)
        # Close the already opened output files before the error exit
        fsel.close()
        fwarn.close()
        raise SystemExit(1)

    datadir.sort()

else:
    station = station[0]   # If option '-a' not given, use only first station
    n_station = 1

    #
    # If no -a option, assume data directory is given in -d dirname
    #
    datadir = [dirname]
    n_datadir = 1

# raise SystemExit

#
# For indication of goodness of the multiple correlation coefficient,
# create rYlGn, a YlGn (Yellow-Green) colormap with reduced dynamic range.
#
# plt.get_cmap() is used instead of plt.cm.get_cmap(): the latter was
# deprecated and later removed from Matplotlib, while pyplot.get_cmap()
# is available in both old and new versions.
#
cmYlGn = plt.get_cmap('YlGn')
rYlGn = ListedColormap(cmYlGn(np.linspace(0.0, 0.6, 256)), name='rYlGn')

#
# Alternative for the cross-correlation matrices (not used): inverted 'hot'
# colormap with reduced dynamic range (from white through yellow to dense red)
#
# cmhot_r = plt.get_cmap('hot_r')
# hotr = ListedColormap(cmhot_r(np.linspace(0.0, 0.7, 256)), name='hotr')

#
# To plot the cross-correlation matrices, use the 'jet' colormap with its
# darkest ends cut off (only the 0.1 .. 0.9 part of the range is sampled)
#
cmjet = plt.get_cmap('jet')
lcmjet = ListedColormap(cmjet(np.linspace(0.1, 0.9, 256)), name='lcmjet')



#
# If -a, loop over the data directories datadir found under dirname.
# Otherwise, n_datadir is 1, n_station is 1, so only one station[0]
# in one experiment is processed from datadir[0].
#
# For every (experiment, station) pair the loop
#   1. reads the band-pol delay files and cleans the delays;
#   2. iteratively rejects the band-pols whose correlation-matrix row median
#      is below threshold_median, logging every iteration in a text file;
#   3. appends the surviving band-pols to the experiment's line of the
#      'selections' file;
#   4. unless -n is given, saves the band-pol and cross-correlation plots.
#
# NOTE(review): band-pol index i is mapped to a name as 'ABCD'[i//2] +
# 'XY'[i%2], which presumes that the sorted file list of a station is
# complete (AX, AY, ..., DY). Confirm for stations with missing files.
#

#
# Regex patterns, compiled once. Raw strings are used because '\/' and '\.'
# in ordinary string literals are invalid escape sequences.
# NOTE(review): the class [A-y] also matches the characters between 'Z' and
# 'a' ([ \ ] ^ _ `); probably [A-Za-z] was meant. Left unchanged.
#
exc_patt = re.compile(r'/[0-9]{4}/')      # Experiment code like /3686/
exn_patt = re.compile(r'(\.[A-y]{2}[0-9]{4}\.|\.[A-y][0-9]{5}\.)')

strpol = 'XY'
strband = 'ABCD'


def _warn(message):
    """Log message in the diagnostics file; echo it on screen unless -a."""
    if not process_all:
        print(message)
    fwarn.write(message + '\n')
    fwarn.flush()


try:
    for iddir in range(n_datadir):
        ddir = datadir[iddir]

        #
        # Put data files for all the stations in fns:
        #
        fns = glob.glob(ddir + '/bandmodel.??????.?.?.?.dat')
        if not fns:          # No data in the directory
            _warn('WARNING: No data on path ' + ddir)
            continue  # =================================================== >>>

        #
        # Extract from path the experiment code like /3686/, and from a data
        # file name the experiment name like .VT9050. or .b17337.
        # Any of the filenames can be picked; we pick the 0-th.
        #
        exc_found = exc_patt.findall(os.path.abspath(ddir))
        exn_found = exn_patt.findall(fns[0])
        if not exc_found or not exn_found:
            # Without -a the directory is not pre-filtered, so the path may
            # lack the 4-digit code; skip it rather than fail on [0].
            _warn('WARNING: Cannot find experiment code or name for path ' +
                  ddir)
            continue  # =================================================== >>>
        exc = exc_found[0][1:-1]
        exn = exn_found[0][1:-1]

        #
        # Prepare the beginning of line in the 'selections' file
        # Each line in the 'selections' file starts with the experiment code
        # followed by the experiment name
        #
        sel_bp = exc + ' ' + exn + '     '

        #
        # Progress
        #
        if process_all:
            sys.stdout.write("\r# %d of %d, Exp. %s %s" %
                             (iddir+1, n_datadir, exc, exn))
            sys.stdout.flush()

        # True as soon as at least one station contributed to sel_bp.
        # (Testing nbandpol after the loop would only reflect the LAST station.)
        any_station_data = False

        #
        # If -a, loop over the stations specified in -s station string.
        #
        for istn in range(n_station):
            stn = station[istn]

            fnames = sorted(glob.glob(ddir + '/bandmodel.??????.' +
                                      stn + '.?.?.dat'))
            nbandpol = len(fnames)
            if nbandpol == 0:          # No data for the station
                _warn('WARNING: No data for station ' + stn +
                      ' on path ' + ddir)
                continue  # =============================================== >>>

            st_exc_exn = stn + '_' + exc + '_' + exn
            st_exc_exn_mf = st_exc_exn + str_mf
            fig_bandpol = outdir + 'bandpol_' + st_exc_exn_mf + '.png'
            fig_xcorrmx = outdir + 'xcorrmx_' + st_exc_exn + '.png'
            fname_bandpol = outdir + 'bandpol_' + st_exc_exn_mf + '.txt'

            #
            # Read all the band-pol data into the array list datlist.
            # atleast_1d keeps a single-row file one-dimensional.
            # In case the datasets have different lengths, truncate all of
            # them to the minimum length min_ndat.
            #
            datlist = [np.atleast_1d(np.loadtxt(fname, dtype=fields))
                       for fname in fnames]
            min_ndat = min(d.shape[0] for d in datlist)
            if min_ndat == 0:
                _warn('WARNING: Empty data file for station ' + stn +
                      ' on path ' + ddir)
                continue  # =============================================== >>>

            dat = np.array([d[:min_ndat] for d in datlist])
            ndat = np.size(dat, 1)    # Number of time samples

            #
            # The cable delays (ps) for all band-pols are put in 2D array
            # delps[nbandpol,ndat]; astype makes a copy that can be modified
            #
            delps = dat['delay_ps'].astype(float)

            #
            # Assume time data the same for all bands. Use time from the
            # 0-th band-pol. The elapsed time is computed from datetime
            # differences; going through time.mktime() would interpret the
            # times as local time and distort them across a DST change.
            #
            tyear = dat[0]['year'].astype(int)
            tdoy = dat[0]['doy'].astype(int)
            thour = dat[0]['hour'].astype(int)
            tminute = dat[0]['minute'].astype(int)
            tsecond = dat[0]['second'].astype(int)

            # Time origin: 00:00:00 of the day of the first sample
            datim0 = datetime.datetime(int(tyear[0]), 1, 1) + \
                     datetime.timedelta(int(tdoy[0]) - 1)
            exp_doy_time = datim0.strftime('%j, %H:%M:%S')

            t_hr = np.zeros(ndat, dtype=float)    # Time in hours
            for itim in range(ndat):
                datim = datetime.datetime(int(tyear[itim]), 1, 1) + \
                        datetime.timedelta(days=int(tdoy[itim]) - 1,
                                           hours=int(thour[itim]),
                                           minutes=int(tminute[itim]),
                                           seconds=int(tsecond[itim]))
                t_hr[itim] = (datim - datim0).total_seconds()/3600.

            #
            # Clean the data
            #
            # Apply median filter to delps, save the filtered data in delps_mf
            #
            delps_mf = np.zeros_like(delps)
            for ibp in range(nbandpol):
                delps_mf[ibp,:] = medfilt(delps[ibp,:], 21)
            #
            # Remove obvious spikes > +-1000
            # Replace them with the median values
            #
            ispike = np.abs(delps) > 1000.
            delps[ispike] = delps_mf[ispike]

            #
            # Compute [nbandpol X nbandpol] correlation matrix of rows
            # of delps_mf[nbandpol,:]
            #
            Rxx_full = np.corrcoef(delps_mf)

            #
            # Assume all the bandpols are good
            #
            bp_good = list(range(nbandpol))
            bp_bad = []

            #
            # Compute medians for the rows (or columns) of the correlation
            # matrix, the diagonal element excluded. Low median points at
            # too weak correlation of a band-pol with other band-pols.
            # Also, put indices of the rows/columns containing all NaN-s (or
            # other non-finites) to bp_bad list and remove them from bp_good.
            #
            corr_median = np.full(nbandpol, np.nan)
            for ibp in range(nbandpol):
                row_Rxx = np.delete(Rxx_full[ibp,:], ibp)
                row_Rxx_fin = row_Rxx[np.isfinite(row_Rxx)]
                if len(row_Rxx_fin) > 0:      # At least one in row is finite
                    corr_median[ibp] = np.median(row_Rxx_fin)
                else:   # All elements are NaNs: ibp-th row and column are bad
                    bp_bad.append(ibp)
                    bp_good.remove(ibp)

            #
            # Compute the multiple correlation coefficients for every bandpol.
            #
            R_mult = mult_corr(Rxx_full, bp_good, bad_nans=True)

            R_mult_good = np.copy(R_mult)
            corr_med_good = np.copy(corr_median)

            #
            # Log file with cross-correlation medians and multiple
            # correlations. 'with' closes it even if an exception is raised.
            #
            with open(fname_bandpol, 'w') as frmul:
                write_title(frmul,
                            'Iterative Selection of Good Band-Pol Channels.',
                            stn, exn, exc, bp_sym, threshold_median)
                write_xcorrmx(frmul, 30*' ' + 'Cross-Correlation Matrix',
                              Rxx_full, bp_good, stn, exn, exc, bp_sym)
                write_numbers(frmul, '       Median ', corr_median, bp_good)
                write_numbers(frmul, '   Multi-Corr ', R_mult, bp_good)
                frmul.write('\n')

                #
                # Successively remove the bandpols with medians of the rows
                # (or columns) of the correlation matrix Rxx_full below the
                # threshold for medians (if there are any)
                #
                iiter = 1

                # nanargmin() raises on an all-NaN array, hence the check
                while np.any(np.isfinite(corr_med_good)):
                    # Index of row with min. median
                    idx_minmed = int(np.nanargmin(corr_med_good))

                    if not (corr_med_good[idx_minmed] < threshold_median):
                        break # ============================================ >>>
                    # The count is that of the band-pols still good (the
                    # all-NaN ones are already excluded)
                    if len(bp_good) <= 2:
                        break # ============================================ >>>

                    bp_bad.append(idx_minmed)
                    bp_good.remove(idx_minmed)

                    #
                    # Compute row medians for the rest of bandpols
                    #
                    corr_med_good = np.full(nbandpol, np.nan)
                    for ibp in bp_good:
                        others = [ix for ix in bp_good if ix != ibp]
                        corr_med_good[ibp] = np.median(Rxx_full[ibp,others])

                    R_mult_good = mult_corr(Rxx_full, bp_good, bad_nans=True)

                    write_xcorrmx(frmul, ' Iteration ' + str(iiter) +
                                  '                  Cross-Correlation Matrix.',
                                  Rxx_full, bp_good, stn, exn, exc, bp_sym)

                    write_numbers(frmul, '       Median ', corr_med_good,
                                  bp_good)
                    # Log the coefficients recomputed on THIS iteration
                    # (not the initial R_mult)
                    write_numbers(frmul, '    Mult-Corr ', R_mult_good,
                                  bp_good)
                    frmul.write('\n')

                    iiter += 1

            #
            # In case only two band-pols left, check if any of them satisfy
            #
            for ibp in list(bp_good):
                if corr_med_good[ibp] < threshold_median:
                    R_mult_good[ibp] = np.nan
                    bp_bad.append(ibp)
                    bp_good.remove(ibp)

            #
            # Prepare sel_st_bp, a string with the station letter and selected
            # band-pols in the format suitable for passing to the program
            # pcc_select.py in option '-s', like 'E:BX,BY,CX,CY'.
            # If nothing is selected the string is just 'E:' (join, unlike
            # cutting off the last character, keeps the colon).
            #
            sel_st_bp = stn + ':' + \
                ','.join(strband[ibp // 2] + strpol[ibp % 2]
                         for ibp in range(nbandpol) if ibp in bp_good)

            #
            # Append the selected band-pols for current station to the line
            # in file 'selections'
            #
            sel_bp += sel_st_bp + '  '
            any_station_data = True

            #
            # Do not create and do not save plots
            #
            if no_plots:
                continue  # =============================================== >>>

            #
            # Create a plot for each band/pol we have data for (on a 4X2 grid)
            #
            fig = plt.figure(figsize=(8.5,11))
            fig.suptitle("Exp. " + exn + " (code " + exc +
                         "), Station " + stn +
                         ". Delay for bands ABCD:XY, Median and R_mult.")

            for ibp in range(nbandpol):
                ip = ibp % 2         # 0 1 0 1 0 1 0 1 ...
                ib = ibp // 2        # 0 0 1 1 2 2 3 3 ...
                iplot = ibp + 1      # subplot index starts from 1

                band_pol = strband[ib] + ':' + strpol[ip]

                ax = fig.add_subplot(4, 2, iplot)

                ax.plot(t_hr, delps_mf[ibp,:], 'g', clip_on=False, lw=2.)
                ax.plot(t_hr, delps[ibp,:], 'b.', markersize=3, clip_on=False)

                xmin, xmax = ax.get_xlim()
                ymin, ymax = ax.get_ylim()
                y_text = ymin + 0.90*(ymax - ymin)
                x_text = xmin + 0.05*(xmax - xmin)

                if ibp in bp_good:
                    #
                    # Print in axes correlation median and mult. correlation
                    # from the last iteration
                    #
                    ax.text(x_text, y_text,
                            'Median:%6.3f;   Mult-corr:%6.3f' %
                            (corr_med_good[ibp], R_mult_good[ibp]), size=12)
                else:
                    #
                    # Rejected: print the initial values, if there are any
                    #
                    if np.isnan(R_mult[ibp]):
                        ax.text(x_text, y_text, ' No Data', size=12)
                    else:
                        ax.text(x_text, y_text,
                                'Median:%6.3f;   Mult-corr:%6.3f' %
                                (corr_median[ibp], R_mult[ibp]), size=12)
                    y_text_rej = ymin + 0.50*(ymax - ymin)
                    x_text_rej = xmin + 0.35*(xmax - xmin)
                    ax.text(x_text_rej, y_text_rej, 'Rejected', color='r',
                            size=14)

                ax.set_ylabel(band_pol + " delay (ps)")
                ax.grid(True)

                if iplot in (7, 8): # two bottom plots
                    ax.set_xlabel("hours since doy " + exp_doy_time)

            fig.tight_layout()
            fig.subplots_adjust(top=0.94)

            fig.savefig(fig_bandpol)

            #
            # Save the plot of cross-correlation matrix
            #
            # Do not plot the diagonal and the half of cross-correlation
            # matrix under it
            #
            Rxx_nan = np.copy(Rxx_full)
            Rxx_nan[np.tril_indices(nbandpol)] = np.nan

            n0_7 = np.arange(nbp_sym)  # Tick values, one per band-pol symbol

            fig2 = plt.figure(figsize=(6,5))
            ax2 = fig2.add_subplot(111)

            # Colormap lcmjet: 'jet' with reduced dynamic range
            xcorimg = ax2.imshow(Rxx_nan, interpolation='none', cmap=lcmjet,
                                 vmin=-1., vmax=1.)
            ax2.set_xticks(n0_7)
            ax2.set_xticklabels(bp_sym)
            ax2.tick_params(axis='x', labeltop=True)
            ax2.set_yticks(n0_7)
            ax2.set_yticklabels(bp_sym)
            fig2.colorbar(xcorimg, shrink=0.8)
            fig2.text(0.1, 0.95, 'Cross-Correlation Matrix. Station ' +
                      stn + ', Exp. ' + exn + ', Code ' + exc)

            fig2.savefig(fig_xcorrmx)

            if show_graph:
                fig.show()
            else:
                plt.close(fig)

            if show_xcorrmx:
                fig2.show()
            else:
                plt.close(fig2)

        if any_station_data:      # Data for at least one station exist
            fsel.write(sel_bp + '\n')
            fsel.flush()
            if not process_all:
                print('code exname     selected band-pols')
                print(sel_bp)

finally:
    #
    # The output files are closed even if the processing was interrupted
    # by an exception
    #
    sys.stdout.flush()
    fsel.close()
    fwarn.close()
