#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
#import matplotlib.dates as mdates
from astropy.time import Time
from matplotlib.collections import LineCollection
#import matplotlib.cm as cm
import argparse, sys

def db(x):
    """Return ``x`` expressed in decibels, i.e. ``10 * log10(x)``.

    Works element-wise on scalars and numpy arrays. Non-positive inputs follow
    numpy's ``log10`` rules (0 -> -inf, negative -> nan).
    """
    log_value = np.log10(x)
    return log_value * 10

def timeAv(spec, r):
    """Average groups of ``r`` consecutive rows (time samples) of a 2D array.

    Rows left over after the last complete group of ``r`` are discarded, so the
    result has shape ``(spec.shape[0] // r, spec.shape[1])``.
    """
    nchan = spec.shape[1]
    nkeep = r * (spec.shape[0] // r)
    # Put each column's time series on its own row, cut it into length-r
    # pieces, average each piece, then restore (time, channel) order.
    chan_major = spec[:nkeep, :].T
    piece_means = chan_major.reshape(-1, r).mean(1)
    return piece_means.reshape(nchan, -1).T

def freqAv(spec, c):
    """Average groups of ``c`` adjacent columns (channels) of a 2D array.

    Columns left over after the last complete group of ``c`` are discarded, so
    the result has shape ``(spec.shape[0], spec.shape[1] // c)``. An array with
    zero rows is handled and yields an empty result of the right width.

    Raises:
        ValueError: if ``c`` is smaller than 1.
    """
    if c < 1:
        raise ValueError("freqAv: averaging length must be >= 1, got {}".format(c))
    nrows = spec.shape[0]
    ngroups = spec.shape[1] // c
    # Explicit dimensions (no -1) so reshape also works when nrows == 0.
    trimmed = spec[:, :ngroups * c]
    return trimmed.reshape(nrows, ngroups, c).mean(axis=2)

def downsample_max(wspec, target_points):
    """Shrink the column count of a 2D array to roughly ``target_points``.

    Columns are grouped into bins of ``ncols // target_points`` (at least 1)
    and each bin is replaced by its maximum. Trailing columns that do not fill
    a whole bin are dropped.
    """
    nrows, ncols = wspec.shape
    binwidth = ncols // target_points
    if binwidth < 1:
        binwidth = 1
    nbins = ncols // binwidth
    usable = nbins * binwidth
    # (rows, bins, binwidth) so the bin contents lie along the last axis
    blocks = wspec[:, :usable].reshape(nrows, nbins, binwidth)
    return np.max(blocks, axis=2)

# Command-line interface. Frequency-selection options (--f0/--f1) are in MHz,
# or in Hz relative to --freqref when that option is given.
parser = argparse.ArgumentParser()
parser.add_argument('-F', '--outfile', '-outfile', help="output file to save plot as")
parser.add_argument('--dpi', help="Resolution to save output as", type=int, default=300)
parser.add_argument('-s', '--show', '-show', help="Display plot to screen", action="store_true")
parser.add_argument('-db', '--db', help="Plot in dB", action="store_true")
parser.add_argument('-f', '--freqref', '-freqref', help="Reference Frequency to subtract from all values (MHz)", type=float)
parser.add_argument('-f0', '--f0', help="First frequency point to plot (MHz, or Hz relative to --freqref if given)", type=float)
parser.add_argument('-f1', '--f1', help="Last frequency point to plot (MHz, or Hz relative to --freqref if given)", type=float)
parser.add_argument('-frest', '--frest', help="Transmit Frequency (MHz)", type=float)
parser.add_argument('-avpol', '--avpol', help="Average pols", action="store_true")
parser.add_argument('-a', '-av', '--average', '-average', help="Average N spectral points", type=int)
parser.add_argument('-t', '-tint', '--tint', help="Average to N seconds", type=float)
parser.add_argument('-p', '-pol', '--pol', help="Select n'th polarisation/IF", type=int, default=0)
parser.add_argument('-tweak', '--tweak', help="Tweak time by factor to allow for non-full integration", type=float)
parser.add_argument('-m', '-mark', '--mark', help="Mark time regions", nargs="+",type=float)
parser.add_argument('-title', '--title', help="Plot Title")
# downsample_max() interprets N as the approximate number of output points,
# not as the number of channels per bin.
parser.add_argument('-downsample', '--downsample', help="Down sample X-axis to approximately N points, taking the maximum over each group of channels", type=int)
parser.add_argument('radarfile', help="Radar Spectrum", nargs="+")
args = parser.parse_args()

first = True  # True until the first file's header has been processed

av = args.average
# Non-positive group sizes would otherwise end in a ZeroDivisionError or a
# nonsensical reshape much later on.
if av is not None and av < 1:
    sys.exit("Error: --average must be a positive number of channels")
if args.downsample is not None and args.downsample < 1:
    sys.exit("Error: --downsample must be a positive number of points")

# Show on screen when asked to, or when there is no output file to save to
show = args.show or args.outfile is None
pol = args.pol
if pol < 0:
    print("Negative polarisation is meaningless. Selecting first pol (0)")
    pol = 0

nInt = None   # Number of integrations to average in time (None = no averaging)

C = 299792458  # Speed of light (m/s)

wspec = []  # Per-file 2D spectra; concatenated once all files are read

doSlice = False  # Set when --f0/--f1 restrict the channel range

startMJD = 0
endMJD = 0
for radarfile in args.radarfile:
    print("Reading ", radarfile)
    # 'with' guarantees the file is closed, including on the early exits below.
    with open(radarfile, 'rb') as fd:
        # Header: 8 x uint32, then a float32 integration time and a float64 LO.
        header = np.fromfile(fd, dtype=np.uint32, count=8)
        if header.size < 8:
            sys.exit("Error: {} is too short to contain a header".format(radarfile))
        (version, nchan, nTotal, firstChan, nIF, bandwidth, mjd, seconds) = header
        tInt = np.fromfile(fd, dtype=np.float32, count=1).item()
        LO = np.fromfile(fd, dtype=np.float64, count=1).item()

        print("tInt=", tInt)

        if pol >= nIF:
            # sys.exit(msg) writes msg to stderr and exits with status 1, so the
            # failure is visible to calling scripts.
            sys.exit("Error selected pol {} too large".format(pol))

        # Remainder of the file: float32 spectra
        data = np.fromfile(fd, dtype='f4')

    if first:
        first = False
        nchanFirst = nchan  # Later files must match so slicing/concatenation work

        chanWidth = bandwidth/nTotal
        f0 = (firstChan / nTotal) * bandwidth + LO - bandwidth/2
        if args.freqref is not None:
            f0 -= args.freqref*1e6

        xvals = (np.arange(nchan).astype('float64') * chanWidth + f0) / 1e6  # MHz

        if args.freqref is not None or args.frest is not None:
            xvals *= 1e6  # Back to Hz

        if args.f0 is not None or args.f1 is not None:
            # Half-open channel range [bchan, echan); by default keep every channel.
            bchan = 0
            echan = int(nchan)
            # --f0/--f1 are Hz when --freqref is given (f0 is then an offset), else MHz
            fscale = 1 if args.freqref is not None else 1e6
            if args.f0 is not None:
                bchan = int(np.rint((args.f0*fscale - f0)/chanWidth))
            if args.f1 is not None:
                echan = int(np.rint((args.f1*fscale - f0)/chanWidth)) + 1
            # Clamp to the recorded band: a negative start index would otherwise
            # wrap around to the end of the array.
            bchan = min(max(bchan, 0), int(nchan))
            echan = min(max(echan, bchan), int(nchan))
            doSlice = True
            xvals = xvals[bchan:echan]

        if av is not None:
            xvals = np.mean(xvals[:(len(xvals)//av)*av].reshape(-1, av), axis=1)

        if xvals.size == 0:
            sys.exit("Error: no frequency channels left to plot (check --f0/--f1/--average)")

        if args.frest is not None:
            xvals *= C / (args.frest*1e6)

        # Plot x-range: first and last (possibly averaged) channel values
        f0 = xvals[0]
        f1 = xvals[-1]

        if args.tint is not None:
            nInt = int(np.rint(args.tint/tInt))
            if nInt < 1:
                print("Integration time {} too small. Not averaging in time".format(args.tint))
                nInt = None
                plotTint = tInt
            else:
                plotTint = nInt*tInt
                print("Using integration time of {:.1f} sec".format(plotTint))
        else:
            plotTint = tInt

        startMJD = mjd + seconds/(60*60*24)
    elif nchan != nchanFirst:
        sys.exit("Error: {} has {} channels but the first file has {}".format(radarfile, nchan, nchanFirst))

    endMJD = mjd + seconds/(60*60*24)

    # Drop a trailing partial integration rather than failing in reshape
    recLen = int(nIF) * int(nchan)
    nRec = data.size // recLen
    if nRec * recLen != data.size:
        print("Warning: ignoring {} trailing values (incomplete integration) in {}".format(data.size - nRec*recLen, radarfile))
        data = data[:nRec*recLen]

    if args.avpol:
        spec = np.sum(data.reshape(-1, nIF, nchan), axis=1)  # Sum pols. Now 2D array
    else:
        spec = data.reshape(-1, nIF, nchan)[:, pol, :]  # Single pol selected. Now 2D array

    print("Nint =", spec.shape[0])

    endMJD += spec.shape[0]*tInt/(60*60*24)

    if doSlice:
        spec = spec[:, bchan:echan]

    if av is not None:
        spec = freqAv(spec, av)

    if nInt is not None:
        spec = timeAv(spec, nInt)

    if args.db:
        spec = db(spec)

    wspec.append(spec)

wspec = np.concatenate(wspec)  # List of per-file 2D spectra -> one (time, channel) array

if args.downsample is not None:
    wspec = downsample_max(wspec, args.downsample)

print("Total Int = ", wspec.shape[0]*plotTint)
print("Total Time spectra= ", wspec.shape[0])

if args.tweak is not None:
    plotTint *= args.tweak

# Start/end times of the data set (not currently used in the plot)
mjd1 = Time(startMJD, format='mjd')
mjd2 = Time(endMJD, format='mjd')

if args.outfile is not None:
    # Compact font sizes for saved figures
    plt.rcParams.update({
        "axes.titlesize": 8,
        "axes.labelsize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "legend.fontsize": 6,
    })

totalTime = wspec.shape[0]*plotTint
# With imshow's default origin='upper', row 0 is placed at the extent's "top"
# value. Passing top=0 and bottom=totalTime puts row i (the i'th integration)
# at y = i*plotTint, with time 0 at the top of the axis and increasing downwards.
# The --mark lines drawn at y=t therefore line up with the data at time t.
plt.imshow(wspec, cmap='viridis', aspect='auto', extent=[f0, f1, totalTime, 0])
if args.frest is not None:
    plt.xlabel('Doppler Velocity (m/s)')
elif args.freqref is not None:
    plt.xlabel('Frequency (Hz)')
else:
    plt.xlabel('Frequency (MHz)')
plt.ylabel('Time (sec)')
if args.title is not None:
    plt.title(args.title)

if args.mark is not None:
    # One horizontal line across the full x-range per marked time
    lines = [[(f0, m), (f1, m)] for m in args.mark]
    lc = LineCollection(lines, color=["y"], lw=1)
    plt.gca().add_collection(lc)

if args.outfile is not None:
    plt.tight_layout()
    plt.savefig(args.outfile, dpi=args.dpi)

if show:
    plt.show()
