#!/usr/bin/env python3

import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from matplotlib.dates import DateFormatter
import argparse, sys, math

def db(x):
    """ Return x expressed in decibels using the 10*log10 (power-ratio) convention.

    Accepts a scalar or a NumPy array; non-positive inputs follow np.log10
    (zero gives -inf, negatives give nan).
    """
    return np.log10(x) * 10

def timeAv(spec, r):
    """ Average a 2D array over consecutive groups of r rows.

    Row i of the result is the mean of input rows i*r .. i*r + r - 1. Trailing
    rows that do not fill a complete group are dropped, so the result has
    spec.shape[0] // r rows and the same number of columns as spec.
    """
    nCols = spec.shape[1]
    nKeep = (spec.shape[0] // r) * r
    kept = spec[:nKeep, :]
    # Transpose so each original column becomes a row, chop every row into
    # length-r runs, average each run, then restore the (rows, cols) layout.
    runMeans = kept.transpose().reshape(-1, r).mean(1)
    return runMeans.reshape(nCols, -1).transpose()

def freqAv(spec, c):
    """ Average a 2D array over consecutive groups of c columns.

    Column j of the result is the mean of input columns j*c .. j*c + c - 1
    within the same row. Trailing columns that do not fill a complete group
    are dropped.
    """
    nRows = spec.shape[0]
    nKeep = (spec.shape[1] // c) * c
    kept = spec[:, :nKeep]
    # Row-major reshape puts each group of c neighbouring columns of one row
    # into its own length-c chunk.
    chunkMeans = kept.reshape(-1, c).mean(1)
    return chunkMeans.reshape(nRows, -1)

# Command-line interface. Options are registered in the same order they are
# listed by --help; several accept legacy single-dash long spellings.
parser = argparse.ArgumentParser()

# Output / display
parser.add_argument("-F", "--outfile", "-outfile",
                    help="output file to save plot as")
parser.add_argument("--dpi",
                    type=int, default=300,
                    help="Resolution to save output as")
parser.add_argument("-avpol", "--avpol",
                    action="store_true",
                    help="Average pols")
parser.add_argument("-s", "--show", "-show",
                    action="store_true",
                    help="Display plot to screen")

# What to plot
parser.add_argument("-m", "--max", "-max",
                    action="store_true",
                    help="Plot max in channel range, not mean")
parser.add_argument("-db", "--db",
                    action="store_true",
                    help="Plot in dB")

# Frequency selection
parser.add_argument("-f", "--freqref", "-freqref",
                    type=float,
                    help="Reference Frequency to subtract from all values (MHz)")
parser.add_argument("-f0", "--f0",
                    type=float,
                    help="First frequency point to plot")
parser.add_argument("-f1", "--f1",
                    type=float,
                    help="Last frequency point to plot")

# Time averaging and decoration
parser.add_argument("-t", "-tint", "--tint",
                    type=float,
                    help="Average to N seconds")
parser.add_argument("-l", "-label", "--label",
                    action="store_true",
                    help="Label Plots")
parser.add_argument("-title", "--title",
                    help="Plot Title")

# Positional input files (one or more)
parser.add_argument("radarfile",
                    nargs="+",
                    help="Radar Spectrum")

args = parser.parse_args()

# True until the first file's header has been read; that header fixes the
# channel selection and averaging parameters applied to every file.
first = True

# Display interactively when asked to, or when there is no output file to write.
show = args.show or args.outfile is None

fig, ax = plt.subplots()
# Show only clock time on the x axis (the date is omitted).
date_form = DateFormatter("%H:%M:%S")
ax.xaxis.set_major_formatter(date_form)

nInt = None     # rows per time-averaging group; None means no time averaging

ampPlot = []    # per-file power arrays, concatenated once all files are read
timePlot = []   # per-file MJD timestamp arrays, parallel to ampPlot

doSlice = False  # set when --f0/--f1 restrict the channel range

for radarfile in args.radarfile:
    print("Reading ", radarfile)
    # `with` closes the file even if a read or reshape below raises.
    with open(radarfile, 'rb') as fd:
        # Header: eight uint32 words, then a float32 integration time and a
        # float64 LO. Converted to Python ints so arithmetic on them cannot
        # overflow uint32.
        (version, nchan, nTotal, firstChan, nIF, bandwidth, mjd, seconds) = (
            int(v) for v in np.fromfile(fd, dtype=np.uint32, count=8))
        tInt = np.fromfile(fd, dtype=np.float32, count=1).item()
        LO = np.fromfile(fd, dtype=np.float64, count=1).item()

        # Remainder of the file: float32 samples shaped (nRows, nIF, nchan).
        spec = np.fromfile(fd, dtype='f4').reshape(-1, nIF, nchan)

    if first:
        # Channel selection and averaging come from the first file's header
        # and are reused for every later file.
        first = False
        chanWidth = bandwidth / nTotal
        # Frequency of the first stored channel, assuming the band of width
        # `bandwidth` is centred on LO (units presumably Hz -- TODO confirm).
        f0 = (firstChan / nTotal) * bandwidth + LO - bandwidth / 2
        if args.freqref is not None:
            f0 -= args.freqref * 1e6

        if args.f0 is not None or args.f1 is not None:
            # NOTE(review): with --freqref, --f0/--f1 are scaled by 1 (treated as
            # Hz offsets from the reference) rather than MHz -- confirm intent.
            fscale = 1 if args.freqref is not None else 1e6
            bchan = 0
            echan = nchan  # exclusive bound; the default keeps the last channel
            if args.f0 is not None:
                bchan = int(np.rint((args.f0 * fscale - f0) / chanWidth))
            if args.f1 is not None:
                echan = int(np.rint((args.f1 * fscale - f0) / chanWidth)) + 1
            # Clamp to the stored channels so out-of-band limits cannot wrap
            # around through negative indices.
            bchan = max(0, min(bchan, nchan))
            echan = max(bchan, min(echan, nchan))
            if echan <= bchan:
                sys.exit(f"Requested frequency range selects no channels ({bchan} to {echan})")
            doSlice = True
            print(f"Using channel range {bchan} to {echan}")

        if args.tint is not None:
            nInt = int(np.rint(args.tint / tInt))
            if nInt == 0:
                print("Integration time {} too small. Not averaging in time".format(args.tint))
                nInt = None
                plotTint = tInt
            else:
                plotTint = nInt * tInt
                print("Using integration time of {:.1f} sec".format(plotTint))
        else:
            plotTint = tInt

    # Select frequency range
    if doSlice:
        spec = spec[:, :, bchan:echan]

    # Collapse the channel axis, leaving shape (nRows, nIF).
    if args.max:
        spec = np.amax(spec, 2)
    else:
        spec = spec.mean(2)

    if nInt is not None:
        spec = timeAv(spec, nInt)

    if args.db:
        spec = db(spec)

    # Fractional-MJD timestamp per row. Rows are plotTint apart, i.e. nInt*tInt
    # after time averaging, not tInt. One identical column per IF.
    rowMjd = mjd + (seconds + np.arange(spec.shape[0], dtype=np.float64) * plotTint) / (24 * 60 * 60)
    times = np.repeat(rowMjd[:, np.newaxis], spec.shape[1], axis=1)

    ampPlot.append(spec)
    timePlot.append(times)

# Stack the per-file lists into single (nRows, nIF) arrays.
ampPlot = np.concatenate(ampPlot)
timePlot = np.concatenate(timePlot)

# MJD 0 is 1858-11-17 00:00, so a fractional MJD is simply days since that
# epoch (the previous "- 50000" offset produced dates ~137 years too early).
MJD_EPOCH = datetime(1858, 11, 17)
datetimePlot = np.array([[MJD_EPOCH + timedelta(days=float(m)) for m in row]
                         for row in timePlot])

if args.avpol:
    if args.db:
        # Values are already in dB: average the linear powers, then convert
        # back, because averaging logarithms is not a power average.
        ampPlot = db(np.mean(10 ** (ampPlot / 10), axis=1, keepdims=True))
    else:
        ampPlot = np.mean(ampPlot, axis=1, keepdims=True)
    nIF = 1

    
# One line per IF column of ampPlot (a single line after --avpol).
for i in range(nIF):
    label = "Chan {}".format(i) if args.label else None
    ax.plot(datetimePlot[:, i], ampPlot[:, i], label=label)

ax.set_xlabel('Time')
ax.set_ylabel('Power (dB)' if args.db else 'Power')
if args.title is not None:
    ax.set_title(args.title)

# Add the legend before saving so the output file includes it, not only the
# on-screen view.
if args.label:
    ax.legend()

if args.outfile is not None:
    plt.savefig(args.outfile, dpi=args.dpi)

if show:
    plt.show()

