#!/usr/bin/env python3

import os
import numpy as np
import warnings
from datetime import datetime
from astropy.time import Time
import matplotlib.pyplot as plt

__version__ = '1.7'
__author__ = 'Lawrence Toomey'


def read_raw_data(filename, hsize=4096):
    """ Open a raw data file and read the header and data

        The first ``hsize`` bytes are decoded as UTF-8 text and parsed as
        ``KEY VALUE`` lines (split on the first run of whitespace). The rest
        of the file is interpreted as samples and reshaped to
        ``(-1, NANT, n_pol, NCHAN, n_bin)``.

        A ``UserWarning`` is issued (the data are still returned) when
        ``hsize`` differs from the header's HDR_SIZE, or when FILE_SIZE plus
        ``hsize`` differs from the size of the file on disk.

        :param string filename: name of file to open
        :param int hsize: header size in bytes
        :return dict hdr: contents of header as dictionary object
        :return numpy.ndarray data: data as numpy.ndarray object
        :raises KeyError: if a required header keyword is missing
        :raises NotImplementedError: if NBIT is not one of 4, 8, 16 or 32
        :raises ValueError: if UTC_START cannot be parsed, or the payload
            size does not fit the sample type / header dimensions
    """
    f_size = os.path.getsize(filename)
    hdr = {}

    with open(filename, 'rb') as fh:

        # read the file header and populate dict object
        hdr_buf = fh.read(hsize).decode('utf-8')

        for line in hdr_buf.split('\n'):
            # lines that do not hold both a key and a value are skipped
            fields = line.split(None, 1)
            if len(fields) == 2:
                hdr[fields[0]] = fields[1]

        n_bit = int(hdr['NBIT'])
        if n_bit == 4:
            # NOTE(review): 4-bit data are read as int8 without any further
            # unpacking here - confirm that this is intended
            n_bit = 8
        n_ant = int(hdr['NANT'])

        # the continuum keywords take precedence over NPOL when present
        if 'CONTINUUM_OUTNPOL' in hdr:
            n_pol = int(hdr['CONTINUUM_OUTNPOL'])
        elif 'CONTINUUM_OUTSTOKES' in hdr:
            n_pol = int(hdr['CONTINUUM_OUTSTOKES'])
            if n_pol == 3:
                n_pol = 4
        else:
            n_pol = int(hdr['NPOL'])

        n_chan = int(hdr['NCHAN'])

        # NOTE: NBIN introduced for calibration data > UTC 2019-04-02
        # Both timestamps are parsed with the same format, so comparing the
        # datetime objects directly orders them correctly.
        utc_fmt = '%Y-%m-%d-%H:%M:%S'
        t_start = datetime.strptime(hdr['UTC_START'].strip(), utc_fmt)
        t_nbin = datetime.strptime('2019-04-02-01:25:28', utc_fmt)

        if t_start < t_nbin and os.fspath(filename).endswith('.cal'):
            n_bin = 1
        else:
            n_bin = int(hdr['NBIN'])

        # set data type according to nbits; reject unsupported widths before
        # reading the (potentially large) payload
        sample_types = {8: 'int8', 16: 'int16', 32: 'f4'}
        if n_bit not in sample_types:
            raise NotImplementedError(
                "Unsupported NBIT value: %s" % hdr['NBIT'].strip())

        # populate numpy.ndarray with data: everything after the header.
        # Reading to end-of-file avoids a negative read length when the file
        # is shorter than hsize.
        data = np.frombuffer(fh.read(), dtype=sample_types[n_bit])

    if hsize != int(hdr['HDR_SIZE']):
        warnings.warn("Header size loaded != HDR_SIZE in header")

    if int(hdr['FILE_SIZE']) + hsize != f_size:
        warnings.warn("File size in header != actual file size.")

    data = data.reshape((-1, n_ant, n_pol, n_chan, n_bin))

    return hdr, data


def main():
    """Command-line entry point: print the header and a data summary, then plot.

    What is plotted depends on the shape of the data returned by
    ``read_raw_data``:

    * one channel: the first samples (at most 10000) of antenna 0, pol 0;
    * two bins: bin 0 ('Cal off') and bin 1 ('Cal on') of the selected
      polarisation against channel number or frequency, or with ``--tcal``
      the ratio ``off / (on - off)``;
    * otherwise: per-bin signal averaged over channels ``c0``..``c1``
      (inclusive) for up to the first two polarisations.
    """
    import argparse

    ap = argparse.ArgumentParser()
    ap.add_argument('file', help="Path to raw data file to read")
    ap.add_argument('-t', '-tcal', '--tcal', help="Compute Tcal", action="store_true")
    ap.add_argument('-f', '-freq', '--freq', help="Plot wrt freq, not channel number", action="store_true")
    ap.add_argument('-p', '-pol', '--pol', help="What polarisation to plot", default=0, type=int)
    ap.add_argument('-c0', '--c0', help="First channel to average for cal32 data", default=None, type=int)
    ap.add_argument('-c1', '--c1', help="Last channel to average for cal32 data", default=None, type=int)
    args = ap.parse_args()

    f_hdr, f_data = read_raw_data(args.file)

    print('\nHeader information ---------------------------------------------')
    print(type(f_hdr), len(f_hdr))
    for key, val in f_hdr.items():
        print(key, ':', val)

    print('\nData information -----------------------------------------------')
    # np.info() prints its report itself and returns None, so it is not
    # wrapped in print()
    np.info(f_data)
    print(type(f_data))
    print('TSUBINT, N_beam, N_pol, N_chan, N_bin')
    print(f_data.shape)

    pol = args.pol
    # Take the dimensions from the array rather than the header: they are
    # integers (np.linspace needs an integer sample count) and the bin count
    # is defined even when the header carries no NBIN keyword.
    npol = f_data.shape[2]
    nchan = f_data.shape[3]
    nbin = f_data.shape[4]

    if args.freq:
        freq = float(f_hdr['FREQ'])
        bw = float(f_hdr['BW'])
        xvals = np.linspace(freq - bw / 2, freq + bw / 2, nchan, endpoint=False)
        # NOTE(review): units are whatever FREQ/BW use in the header - confirm
        chan_label = 'Frequency'
    else:
        xvals = np.arange(nchan)
        chan_label = 'Frequency Channel No.'

    print("*** ", f_data.shape)

    if nchan == 1:
        # baseband data: plot the leading samples against their sample index
        # (xvals has a single entry here and cannot be used as the x axis)
        samples = f_data[0:10000, 0, 0, 0, 0]
        print('\nPlotting first {} samples of baseband data...'.format(samples.size))
        plt.xlabel('Sample No.')
        plt.ylabel('Signal (arbitary)')
        plt.plot(np.arange(samples.size), samples)
    elif nbin == 2:
        cal_off = f_data[0, 0, pol, :, 0]
        cal_on = f_data[0, 0, pol, :, 1]
        if args.tcal:
            # promote to float first so that (on - off) cannot wrap around
            # for integer sample types
            off = cal_off.astype(float)
            on = cal_on.astype(float)
            plt.plot(xvals, off / (on - off))
        else:
            plt.plot(xvals, cal_off, label='Cal off')
            plt.plot(xvals, cal_on, label='Cal on')
            plt.legend()
        plt.xlabel(chan_label)
        plt.ylabel('Signal Amplitude (arbitary)')
    else:
        # default averaging window: the central 80% of the band
        c0 = int(round(nchan / 10))
        c1 = int(round(nchan * 0.9))

        if args.c0 is not None:
            if 0 <= args.c0 < nchan:
                c0 = args.c0
            else:
                print("Warning -c0 {} is out of range. Ignoring".format(args.c0))
        if args.c1 is not None:
            if 0 <= args.c1 < nchan:
                c1 = args.c1
            else:
                print("Warning -c1 {} is out of range. Ignoring".format(args.c1))

        print("Averaging channels {} to {}".format(c0, c1))
        # c1 is inclusive for the user; the slice end is exclusive
        c1 += 1

        for p in range(min(npol, 2)):
            plt.plot(np.arange(nbin), np.mean(f_data[0, 0, p, c0:c1, :], axis=0), label="Pol {}".format(p))
        plt.xlabel('Cal bin')
        plt.ylabel('Signal Amplitude (arbitary)')
        plt.legend()
    plt.show()


if __name__ == "__main__":
    main()


