#!/usr/bin/env python3

import numpy as np
from astropy.time import Time
import argparse, sys, math

SECONDS_PER_DAY = 60 * 60 * 24


def build_parser():
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-avpol', '--avpol', help="Average pols", action="store_true")
    parser.add_argument('-p', '-pol', '--pol', help="Select n'th polarisation/IF", type=int, default=0)
    parser.add_argument('-s', '--show', '-show', help="Display plot to screen", action="store_true")
    parser.add_argument('-f', '--freqref', '-freqref', help="Reference Frequency to subtract from all values (MHz)", type=float)
    parser.add_argument('-f0', '--f0', help="First frequency point to plot", type=float)
    parser.add_argument('-f1', '--f1', help="Last frequency point to plot", type=float)
    parser.add_argument('-a', '-av', '--average', '-average', help="Average N spectral points", type=int)
    parser.add_argument('radarfile', help="Radar Spectrum", nargs='+')
    return parser


def read_header(fd):
    """Read the fixed-size binary header from an open spectrum file.

    The header is read as eight uint32 words (version, nchan, nTotal,
    firstChan, nIF, bandwidth, mjd, seconds), followed by one float32 and
    one float64 (the LO frequency), all in native byte order.

    Args:
        fd: Binary file object positioned at the start of the file.

    Returns:
        dict with plain Python scalars under the keys ``version``,
        ``nchan``, ``ntotal``, ``first_chan``, ``n_if``, ``bandwidth``,
        ``mjd``, ``seconds`` and ``lo``.

    Raises:
        ValueError: If the file is too short to hold the header, or the
            header declares zero channels.
    """
    words = np.fromfile(fd, dtype=np.uint32, count=8)
    # The float32 is read only to advance past it; its value is not used
    # (presumably the integration time -- TODO confirm against the writer).
    t_int = np.fromfile(fd, dtype=np.float32, count=1)
    lo = np.fromfile(fd, dtype=np.float64, count=1)
    if words.size < 8 or t_int.size < 1 or lo.size < 1:
        raise ValueError("file too short to contain a complete header")

    # Convert to Python scalars: uint32 arithmetic can wrap around, and
    # 1-element arrays would otherwise leak into every later computation.
    (version, nchan, ntotal, first_chan,
     n_if, bandwidth, mjd, seconds) = (int(word) for word in words)
    if nchan == 0 or ntotal == 0:
        raise ValueError("header declares zero channels")

    return {
        'version': version,
        'nchan': nchan,
        'ntotal': ntotal,
        'first_chan': first_chan,
        'n_if': n_if,
        'bandwidth': bandwidth,
        'mjd': mjd,
        'seconds': seconds,
        'lo': float(lo[0]),
    }


def read_spectra(fd, n_if, nchan):
    """Read the remaining float32 samples as (integration, IF, channel).

    Args:
        fd: Binary file object positioned just after the header.
        n_if: Number of polarisations/IFs per integration.
        nchan: Number of channels per IF.

    Returns:
        float32 array of shape ``(n_integrations, n_if, nchan)``.

    Raises:
        ValueError: If the data section is empty or is not a whole number
            of integrations.
    """
    samples = np.fromfile(fd, dtype='f4')
    per_integration = n_if * nchan
    if samples.size == 0 or samples.size % per_integration:
        raise ValueError("data section is empty or not a whole number of integrations")
    return samples.reshape(-1, n_if, nchan)


def block_average(values, n):
    """Average consecutive groups of ``n`` points, dropping any remainder."""
    values = np.asarray(values)
    usable = (len(values) // n) * n
    return np.mean(values[:usable].reshape(-1, n), axis=1)


def frequency_axis(header, freqref=None, fmin=None, fmax=None, average=None):
    """Build the frequency axis matching the selected/averaged spectrum.

    Args:
        header: Mapping as returned by ``read_header``.
        freqref: Reference frequency in MHz subtracted from the axis, or None.
        fmin: First frequency to keep, or None. Interpreted in MHz when
            ``freqref`` is None, otherwise in Hz relative to ``freqref``.
        fmax: Last frequency to keep (inclusive), same units as ``fmin``.
        average: Number of adjacent channels to average, or None.

    Returns:
        Tuple ``(xvals, bchan, echan)``. ``xvals`` is in MHz when
        ``freqref`` is None, otherwise in Hz relative to ``freqref``.
        ``bchan:echan`` is the channel slice (exclusive end) to apply to
        the spectrum before averaging.
    """
    nchan = header['nchan']
    bandwidth = header['bandwidth']
    chan_width = bandwidth / header['ntotal']
    f0 = (header['first_chan'] / header['ntotal']) * bandwidth + header['lo'] - bandwidth / 2
    if freqref is not None:
        f0 -= freqref * 1e6

    xvals = (np.arange(nchan).astype('float64') * chan_width + f0) / 1e6  # MHz

    # echan is an exclusive end, so the "keep everything" default is nchan.
    bchan, echan = 0, nchan
    if fmin is not None or fmax is not None:
        fscale = 1 if freqref is not None else 1e6
        if fmin is not None:
            bchan = int(np.rint((fmin * fscale - f0) / chan_width))
        if fmax is not None:
            echan = int(np.rint((fmax * fscale - f0) / chan_width) + 1)
        # Clamp to the band: a negative index would silently wrap around
        # to the other end of the array when used in a slice.
        bchan = min(max(bchan, 0), nchan)
        echan = min(max(echan, 0), nchan)
    xvals = xvals[bchan:echan]

    if average is not None:
        xvals = block_average(xvals, average)

    if freqref is not None:
        xvals *= 1e6  # Back to Hz

    return xvals, bchan, echan


def peak_channel(data, pol, avpol=False, bchan=0, echan=None, average=None):
    """Return the index of the strongest point of the time-averaged spectrum.

    Args:
        data: Array-like of shape ``(n_integrations, n_if, nchan)``.
        pol: Polarisation/IF index to use when ``avpol`` is false.
        avpol: If true, combine all polarisations instead of selecting one.
        bchan: First channel to keep.
        echan: One past the last channel to keep (None keeps to the end).
        average: Number of adjacent channels to average, or None.

    Returns:
        Integer index into the sliced (and averaged) spectrum, i.e. into
        the axis returned by ``frequency_axis`` for the same selection.

    Raises:
        ValueError: If the selection leaves no channels.
    """
    spectrum = np.mean(data, axis=0).astype('float64')
    spectrum = spectrum[:, bchan:echan]

    if avpol:
        # A sum rather than a mean: the scale does not move the peak.
        spectrum = np.sum(spectrum, axis=0)
    else:
        spectrum = spectrum[pol, :]

    if average is not None:
        spectrum = block_average(spectrum, average)

    if spectrum.size == 0:
        raise ValueError("no channels left after frequency selection/averaging")
    return int(np.argmax(spectrum))


def process_file(radarfile, pol, args):
    """Find the peak frequency of one spectrum file.

    The frequency axis is rebuilt from each file's own header.

    Args:
        radarfile: Path of the file to read.
        pol: Validated (non-negative) polarisation index.
        args: Parsed command-line arguments.

    Returns:
        Tuple ``(mjd, peak_freq)`` with ``mjd`` the fractional MJD from the
        header and ``peak_freq`` in the units of ``frequency_axis``.

    Raises:
        ValueError: If the file is malformed or the selection is empty.
        SystemExit: If ``pol`` is not present in the file.
    """
    # The context manager closes the file on every exit path.
    with open(radarfile, 'rb') as fd:
        header = read_header(fd)
        if pol >= header['n_if']:
            sys.exit("Error selected pol {} too large".format(pol))
        data = read_spectra(fd, header['n_if'], header['nchan'])

    mjd = header['mjd'] + header['seconds'] / SECONDS_PER_DAY
    xvals, bchan, echan = frequency_axis(
        header, freqref=args.freqref, fmin=args.f0, fmax=args.f1,
        average=args.average)
    peak = peak_channel(data, pol, avpol=args.avpol, bchan=bchan,
                        echan=echan, average=args.average)
    return mjd, xvals[peak]


def main():
    """Print the peak frequency of each input file and optionally plot it.

    One line ``<iso time> <file>  <peak frequency>`` is written to stdout
    per file; progress and error messages go to stderr.
    """
    parser = build_parser()
    args = parser.parse_args()

    pol = args.pol
    if pol < 0:
        print("Negative polarisation is meaningless. Selecting first pol (0)",
              file=sys.stderr)
        pol = 0

    if args.average is not None and args.average < 1:
        parser.error("--average must be a positive integer")

    if args.show:
        # Imported up front so a missing matplotlib fails before any work.
        import matplotlib.pyplot as plt

    times = []
    freqs = []
    for radarfile in args.radarfile:
        print("Reading ", radarfile, file=sys.stderr)
        try:
            mjd, peak_freq = process_file(radarfile, pol, args)
        except ValueError as err:
            sys.exit("Error processing {}: {}".format(radarfile, err))

        stamp = Time(mjd, format='mjd').iso
        print("{} {}  {}".format(stamp, radarfile, peak_freq))

        times.append(mjd)
        freqs.append(peak_freq)

    if args.show:
        elapsed = np.array(times)
        elapsed -= elapsed[0]
        elapsed *= SECONDS_PER_DAY

        plt.plot(elapsed, freqs)
        plt.xlabel('Time (sec)')
        # The axis is in Hz only when a reference frequency was subtracted.
        plt.ylabel('Freq (Hz)' if args.freqref is not None else 'Freq (MHz)')
        plt.show()


if __name__ == "__main__":
    main()
