#!/usr/bin/env python3
#
# Compute Tsys vs time for Medusa UWB data
#   Based on code by Lawrence Toomey
#
import os, sys
import numpy as np
import warnings
from datetime import datetime
from astropy.time import Time, TimeDelta
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

__version__ = '0.2'
__author__ = 'Chris Phillips'


def read_raw_data(filename, hsize=4096):
    """ Open a raw data file and read the header and data

        The first ``hsize`` bytes are decoded as UTF-8 text and parsed as
        ``KEY value`` lines (split on the first run of whitespace); lines
        without both a key and a value are ignored.  Everything after the
        header is interpreted as samples and reshaped to
        ``(-1, NANT, npol, NCHAN, nbin)`` where

        * ``npol`` is CONTINUUM_OUTNPOL if present, else CONTINUUM_OUTSTOKES
          (a value of 3 is treated as 4), else NPOL;
        * ``nbin`` is 1 for ``.cal`` files whose UTC_START is before
          2019-04-02-01:25:28, otherwise the NBIN header value.

        A warning is issued if ``hsize`` differs from HDR_SIZE, or if
        FILE_SIZE + ``hsize`` differs from the size of the file on disk.

        :param string filename: name of file to open
        :param int hsize: header size in bytes
        :return dict hdr: contents of header as dictionary object; values are
            strings with surrounding whitespace removed, plus an extra
            ``'mjd'`` entry holding UTC_START as an astropy Time (MJD format)
        :return numpy.ndarray data: data as numpy.ndarray object
        :raises KeyError: if a required header keyword is missing
        :raises NotImplementedError: if NBIT is not 4, 8, 16 or 32
    """
    time_fmt = '%Y-%m-%d-%H:%M:%S'
    f_size = os.path.getsize(filename)
    hdr = {}

    with open(filename, 'rb') as fh:

        # read the file header and populate dict object
        hdr_buf = fh.read(hsize).decode('utf-8')

        for line in hdr_buf.split('\n'):
            fields = line.split(None, 1)
            if len(fields) == 2:
                # split(None, 1) leaves trailing padding on the value, which
                # would break exact string comparisons by callers
                hdr[fields[0]] = fields[1].strip()

        n_bit = int(hdr['NBIT'])
        if n_bit == 4:
            # 4-bit data is read with an 8-bit dtype; no unpacking is done
            n_bit = 8
        n_ant = int(hdr['NANT'])

        if 'CONTINUUM_OUTNPOL' in hdr:
            n_pol = int(hdr['CONTINUUM_OUTNPOL'])
        elif 'CONTINUUM_OUTSTOKES' in hdr:
            n_pol = int(hdr['CONTINUUM_OUTSTOKES'])
            if n_pol == 3:
                n_pol = 4
        else:
            n_pol = int(hdr['NPOL'])
        n_chan = int(hdr['NCHAN'])

        t0 = Time(datetime.strptime(hdr['UTC_START'].strip(), time_fmt))
        t_mjd = Time(t0.mjd, format='mjd')

        # NOTE: NBIN introduced for calibration data > UTC 2019-04-02
        td_0 = Time(datetime.strptime('2019-04-02-01:25:28', time_fmt))
        td_mjd = Time(td_0.mjd, format='mjd')

        if t_mjd < td_mjd and filename.endswith('.cal'):
            n_bin = 1
        else:
            n_bin = int(hdr['NBIN'])

        hdr['mjd'] = t_mjd

        # populate numpy.ndarray with data
        data_buf = fh.read(f_size - hsize)

        # set data type according to nbits
        if n_bit in (8, 16):
            data = np.frombuffer(data_buf, dtype='int%i' % n_bit)
        elif n_bit == 32:
            data = np.frombuffer(data_buf, dtype='f4')
        else:
            raise NotImplementedError("Unsupported NBIT: %i" % n_bit)

    if hsize != int(hdr['HDR_SIZE']):
        warnings.warn("Header size loaded != HDR_SIZE in header")

    if int(hdr['FILE_SIZE']) + hsize != f_size:
        warnings.warn("File size in header != actual file size.")

    data = data.reshape((-1, n_ant, n_pol, n_chan, n_bin))

    return hdr, data


if __name__ == "__main__":
    # Command-line entry point: for every input file compute, per zoom band
    # (or the full band) and per polarisation (at most the first two), the
    # channel-averaged ratio bin0 / (bin1 - bin0) of the first time sample
    # of antenna 0, then either print one line per file or plot the values
    # against time.
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument('files', help="Path to raw data files to read", nargs="+")
    ap.add_argument('-p', '-plot', '--plot', help="Plot Tcal", action="store_true")
    ap.add_argument('-P', '-point', '--point', help="Plot as points", action="store_true")
    ap.add_argument('-nozoom', '--nozoom', help="Don't separate zooms", action="store_true")
    args = ap.parse_args()

    doplot = args.plot

    # Only filled in plot mode: one time stamp and one list of values per file
    times = []
    yval = []
    point = '.' if args.point else ''

    mjd0 = None  # start time of the first file, reference for printed offsets

    for fname in args.files:

        f_hdr, f_data = read_raw_data(fname)

        if mjd0 is None:
            mjd0 = f_hdr['mjd']

        offset = float(f_hdr['OBS_OFFSET'])
        bytespersec = float(f_hdr['BYTES_PER_SECOND'])
        toff = offset / bytespersec        # offset in seconds
        if f_data.shape[4] == 2:
            toff /= 16  # Header wrong for averaged data
        nchan = int(f_hdr['NCHAN'])
        npol = int(f_hdr['NPOL'])
        bw = float(f_hdr['BW'])
        freq = float(f_hdr['FREQ'])

        # Explicit seconds: adding a bare float to a Time relies on astropy
        # guessing the unit (days)
        thisMJD = f_hdr['mjd'] + TimeDelta(toff, format='sec')  # Astropy Time
        toff += (f_hdr['mjd'] - mjd0).sec    # Float, seconds since first file

        chanBW = bw / nchan
        bandStart = freq - bw / 2

        # One array of channel indices per active zoom band
        chanrange = []

        # Header flags are compared after strip(); header lines may be padded
        vlbi = f_hdr.get('PERFORM_VLBI', '').strip()
        if not args.nozoom and vlbi in ('true', '1'):
            for zoom in ('ZOOM1', 'ZOOM2'):
                if f_hdr.get(zoom + '_ACTIVE', '').strip() != '1':
                    continue
                zoomBW = float(f_hdr[zoom + '_BW'])
                # float(), not int(): the value may be written as e.g. 22332.000000
                zoomFreq = float(f_hdr[zoom + '_FREQUENCY'])
                startZoom = zoomFreq - zoomBW / 2
                z1 = round((startZoom - bandStart) / chanBW)
                nz = round(zoomBW / chanBW)
                chans = np.arange(z1, z1 + nz, dtype=int)

                # Negative indices would silently wrap around and indices
                # >= nchan would raise, so keep only channels inside the band
                inband = (chans >= 0) & (chans < nchan)
                if not inband.all():
                    warnings.warn("{}: {} extends outside the band; "
                                  "ignoring out-of-band channels".format(fname, zoom))
                    chans = chans[inband]
                chanrange.append(chans)

        if not chanrange:
            chanrange.append(np.arange(nchan, dtype=int))  # Full bandwidth

        if doplot:
            times.append(thisMJD)
        else:
            print("{} {:7.1f} ".format(thisMJD.iso, toff), end='')

        # Never index past the polarisation axis actually present in the
        # data, which can be smaller than the header NPOL
        nplot_pol = min(npol, 2, f_data.shape[2])

        tcal = []
        for tchans in chanrange:
            for p in range(nplot_pol):
                # Convert to float first so integer-typed data cannot
                # overflow in the subtraction.
                # NOTE(review): presumably bin 0 is cal-off and bin 1 is
                # cal-on - confirm.
                bin0 = f_data[0, 0, p, tchans, 0].astype(float)
                bin1 = f_data[0, 0, p, tchans, 1].astype(float)
                t = (bin0 / (bin1 - bin0)).mean()
                if doplot:
                    tcal.append(t)
                else:
                    print(" {:.1f}".format(t), end='')

        if doplot:
            yval.append(tcal)
        else:
            print()

    if doplot:
        xvals = Time(times)

        # Transpose: one series per (zoom, pol) column instead of one row per file
        series = list(map(list, zip(*yval)))

        # plt.plot_date is deprecated; plot() on an axis set to dates is the
        # documented replacement
        ax = plt.gca()
        ax.xaxis_date()
        for y in series:
            ax.plot(xvals.plot_date, y, point)

        ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        plt.show()


# UTC_START           2020-12-07-03:55:44    
#ZOOM1_ACTIVE        1
#ZOOM1_BW            64
#ZOOM1_FREQUENCY     22332
#ZOOM1_MODE          vlbi
#ZOOM2_ACTIVE        0
#ZOOM2_BW            4
#ZOOM2_FREQUENCY     832
#ZOOM2_MODE          baseband
#CALFREQ             100
#FREQ                22364.000000           
#BW                  128.000000             
#NCHAN               128                    
#TSYS_AVG_TIME       5
#OBS_OFFSET          464322560              
#BYTES_PER_SECOND    819.200000             
#BYTES_PER_SECOND    13107.200000           


