#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
from astropy.time import Time
import argparse, sys, math

def db(x):
    """Return ``x`` expressed in decibels using the 10*log10 convention.

    Accepts scalars or numpy arrays; the result is computed elementwise.
    """
    return np.log10(x) * 10

# Command-line interface.  Most options accept several spellings
# (e.g. -F / --outfile / -outfile).
parser = argparse.ArgumentParser()
add = parser.add_argument

# Output control
add('-F', '--outfile', '-outfile', help="output file to save plot as")
add('--dpi', help="Resolution to save output as", type=int, default=300)
add('-s', '--show', '-show', help="Display plot to screen", action="store_true")
add('-o', '--offset', '-offset', help="Print frequency of offset", action="store_true")

# Spectrum selection and processing
add('-avpol', '--avpol', help="Average pols", action="store_true")
add('-p', '-pol', '--pol', help="Select n'th polarisation/IF", type=int, default=0)
add('-db', '--db', help="Plot in dB", action="store_true")
add('-f', '--freqref', '-freqref', help="Reference Frequency to subtract from all values (MHz)", type=float)
add('-f0', '--f0', help="First frequency point to plot", type=float)
add('-f1', '--f1', help="Last frequency point to plot", type=float)
add('-a', '-av', '--average', '-average', help="Average N spectral points", type=int)

# Positional: exactly two spectrum files to cross-correlate
add('radarfile', help="Radar Spectrum", nargs=2)
args = parser.parse_args()

first = True

# Show the plot on screen if explicitly requested, or when no other output
# (a saved file or a printed offset) was asked for.
show = args.show or (args.outfile is None and not args.offset)

pol = args.pol
if pol < 0:
    # Warn (on stderr, like the other diagnostics) and fall back to pol 0.
    print("Negative polarisation is meaningless. Selecting first pol (0)", file=sys.stderr)
    pol = 0

# Mode passed to np.correlate.  'same' is used for full-band spectra of equal
# length (zero lag at the centre of the output).  When a frequency window is
# selected this switches to 'valid', so the second file's window slides
# across a wider, padded window taken from the first file.
corrMode = 'same'

# Number of adjacent spectral points averaged together (1 = no averaging).
av = args.average if args.average is not None else 1
if av < 1:
    sys.exit("Error: number of spectral points to average must be >= 1")

spec = []

doSlice = False
for ifile, radarfile in enumerate(args.radarfile):
    first = (ifile == 0)
    print("Reading ", radarfile, file=sys.stderr)
    with open(radarfile, 'rb') as fd:
        # Header: eight uint32 values, then one float32 and one float64.
        # .tolist()/[0] give plain Python scalars rather than numpy arrays.
        (version, nchan, nTotal, firstChan, nIF, bandwidth, mjd, seconds) = \
            np.fromfile(fd, dtype=np.uint32, count=8).tolist()
        tInt = float(np.fromfile(fd, dtype=np.float32, count=1)[0])
        LO = float(np.fromfile(fd, dtype=np.float64, count=1)[0])

        if pol >= nIF:
            # sys.exit(str) reports on stderr and exits with a failure status.
            sys.exit("Error selected pol {} too large".format(pol))

        # Remaining payload is float32, shaped (integration, IF, channel).
        data = np.fromfile(fd, dtype='f4').reshape(-1, nIF, nchan)

    if first:
        # The frequency axis and channel window come from the first file only.
        # NOTE(review): assumes the second file has the same channelisation;
        # this is not checked.
        chanWidth = bandwidth / nTotal
        f0 = (firstChan / nTotal) * bandwidth + LO - bandwidth / 2
        if args.freqref is not None:
            f0 -= args.freqref * 1e6

        if args.f0 is not None or args.f1 is not None:
            # With -f the window limits are used unscaled (presumably Hz
            # relative to the reference); otherwise they are scaled by 1e6
            # (presumably MHz -> Hz).
            fscale = 1 if args.freqref is not None else 1e6
            bchan = 0
            echan = nchan  # exclusive end index
            if args.f0 is not None:
                bchan = int(np.rint((args.f0 * fscale - f0) / chanWidth))
            if args.f1 is not None:
                echan = int(np.rint((args.f1 * fscale - f0) / chanWidth) + 1)
            # Keep the window inside the band; a negative start index would
            # otherwise wrap around to the end of the spectrum.
            bchan = max(bchan, 0)
            echan = min(echan, nchan)
            if echan <= bchan:
                sys.exit("Error: selected frequency range contains no channels")
            # Pad the first file's window by half its width on each side
            # (clipped to the band) so the second window can slide across it.
            delta = echan - bchan + 1
            bbchan = max(bchan - delta // 2, 0)
            eechan = min(echan + delta // 2, nchan)
            doSlice = True
            corrMode = 'valid'

    # Mean over integrations -> (IF, channel), promoted to float64.
    this_specAv = np.mean(data, axis=0).astype('float64')

    if doSlice:
        if first:
            this_specAv = this_specAv[:, bbchan:eechan]
        else:
            this_specAv = this_specAv[:, bchan:echan]

    if args.avpol:
        # NOTE: this sums the IFs rather than averaging them, despite the
        # option's help text.
        this_specAv = np.sum(this_specAv, axis=0)
    else:
        this_specAv = this_specAv[pol, :]

    if av > 1:
        # Trailing points that do not fill a complete group of `av` are dropped.
        nkeep = (len(this_specAv) // av) * av
        this_specAv = np.mean(this_specAv[:nkeep].reshape(-1, av), axis=1)

    if args.db:
        this_specAv = db(this_specAv)

    spec.append(this_specAv)

corr = np.correlate(spec[0], spec[1], corrMode)

# Map correlation output index k to a frequency offset.  With
# np.correlate's c[k] = sum_n a[n+k] v[n], a peak means a feature sits at a
# higher channel in the first file than in the second by the offset.
nlag = corr.shape[0]
if doSlice:
    # 'valid' mode: index 0 aligns the start of the second window with the
    # start of the padded first window, (bchan - bbchan) channels below zero lag.
    lag0 = bchan - bbchan
else:
    # 'same' mode with equal-length inputs: zero lag is at index nlag//2.
    lag0 = (nlag // 2) * av
# After spectral averaging, each correlation step spans `av` channels.
xvals = (np.arange(nlag) * av - lag0) * chanWidth

if args.offset:
    i = np.argmax(corr)
    print("{} - {}: Offset = {}".format(args.radarfile[0], args.radarfile[1], xvals[i]))

if show or args.outfile is not None:
    ax = plt.gca()
    plt.plot(xvals, corr)
    # Annotate with the two input file names (axes-fraction coordinates).
    for ypos, name in zip((0.95, 0.90), args.radarfile):
        plt.text(0.03, ypos, name, horizontalalignment='left',
                 verticalalignment='center', transform=ax.transAxes)
    # The original set this via xlabel and then immediately overwrote it.
    # Note that np.correlate output is not normalised.
    plt.ylabel("Correlation Coefficient")
    plt.xlabel("Offset (Hz)")

if args.outfile is not None:
    plt.savefig(args.outfile, dpi=args.dpi)

if show:
    plt.show()

