#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
from astropy.time import Time
import argparse, sys, math

def db(x):
    """Express a linear power ratio in decibels.

    Works element-wise, so *x* may be a scalar, list or NumPy array.
    Returns 10 * log10(x).
    """
    decades = np.log10(x)
    return 10 * decades

# Command-line interface.
# Note: --f0/--f1 are interpreted in the x-axis units used for channel
# selection in the file loop: absolute MHz normally, or a Hz offset from
# the reference frequency when --freqref is given.
parser = argparse.ArgumentParser()
parser.add_argument('-F', '--outfile', '-outfile', help="output file to save plot as")
parser.add_argument('--dpi', help="Resolution to save output as", type=int, default=300)
parser.add_argument('-s', '--show', '-show', help="Display plot to screen", action="store_true")
parser.add_argument('-avpol', '--avpol', help="Average pols", action="store_true")
parser.add_argument('-db', '--db', help="Plot in dB", action="store_true")
parser.add_argument('-f', '--freqref', '-freqref', help="Reference Frequency to subtract from all values (MHz)", type=float)
parser.add_argument('-f0', '--f0', help="First frequency point to plot (MHz, or Hz offset from --freqref if given)", type=float)
parser.add_argument('-f1', '--f1', help="Last frequency point to plot (MHz, or Hz offset from --freqref if given)", type=float)
parser.add_argument('-frest', '--frest', help="Transmit Frequency (MHz)", type=float)
parser.add_argument('-a', '-av', '--average', '-average', help="Average N spectral points", type=int)
parser.add_argument('-overplot', '--overplot', help="Overplot each file", action="store_true")
parser.add_argument('-l', '-label', '--label', help="label plots for when overplotting", nargs="+",type=float)
parser.add_argument('radarfile', help="Radar Spectrum", nargs="+")
args = parser.parse_args()

C = 299792458  # Speed of light (m/s)

first = True   # True until the first spectrum is accumulated (cleared in the file loop)
totalInt = 0   # Running count of integrations folded into the average

# Show interactively when explicitly asked, or when no output file is given.
show = args.show or args.outfile is None

chanrange = None  # NOTE(review): not referenced elsewhere in this file

def doplot(plotSpec, nIF):
    """Plot the first ``nIF`` rows of ``plotSpec`` against the global ``xvals``.

    Parameters
    ----------
    plotSpec : 2-D array
        Indexed as ``plotSpec[i, :]``, giving one spectrum per IF ``i``.
    nIF : int
        Number of IF rows to plot.

    Behaviour is also driven by the parsed command-line ``args``:
      --avpol    collapse the IF axis to its mean before plotting
      --average  average each run of N adjacent channels; a trailing
                 partial run is dropped, matching how ``xvals`` is averaged
      --db       plot 10*log10 of the resulting spectrum
    """
    if args.avpol:
        # Use the mean, not the sum, so the option does what its help text
        # ("Average pols") says. This also matches --average, which uses np.mean.
        plotSpec = np.mean(plotSpec, axis=0, keepdims=True)
        nIF = 1

    for i in range(nIF):
        spec = plotSpec[i, :]
        if args.average is not None:
            av = args.average
            # Drop the trailing remainder so the reshape into (-1, av) is exact.
            spec = np.mean(spec[:(len(spec)//av)*av].reshape(-1, av), axis=1)
        if args.db:
            spec = db(spec)
        plt.plot(xvals, spec)

# Process each input file.
# - The x axis (frequency, or velocity with --frest) is built from the first
#   file that contains data, and is reused for all later files.
# - With --overplot, each file's mean spectrum is plotted separately.
# - Otherwise, an integration-weighted mean over all files is plotted at the end.
doSlice = False
for radarfile in args.radarfile:
    print("Reading ", radarfile)
    # Read everything inside 'with' so the file is closed even when a short
    # or malformed file makes parsing/reshaping raise.
    with open(radarfile, 'rb') as fd:
        # Header: 8 x uint32, then 1 x float32, then 1 x float64 (native byte order).
        (version, nchan, nTotal, firstChan, nIF, bandwidth, mjd, seconds) = np.fromfile(fd, dtype=np.uint32, count=8)
        # Unpack to scalars. A shape-(1,) LO would make f0 an array, and
        # int() on a 1-element array is deprecated since NumPy 1.25.
        (tInt,) = np.fromfile(fd, dtype=np.float32, count=1)  # read to advance the file; unused here
        (LO,) = np.fromfile(fd, dtype=np.float64, count=1)
        # Remaining float32 samples, grouped as (integration, IF, channel).
        data = np.fromfile(fd, dtype='f4').reshape(-1, nIF, nchan)

    nInt = data.shape[0]  # Number of integrations in this file
    if nInt == 0:
        # A mean over zero integrations is NaN, and NaN*0 would poison the
        # running weighted average, so skip files without data.
        print("Skipping ", radarfile, ": no integrations")
        continue

    mjd += seconds/(60*60*24)
    time = Time(mjd, format='mjd').iso

    if first:
        chanWidth = bandwidth/nTotal
        f0 = (firstChan / nTotal) * bandwidth + LO - bandwidth/2
        if args.freqref is not None:
            f0 -= args.freqref*1e6

        f1 = f0 + chanWidth * nchan
        xvals = (np.arange(nchan).astype('float64') * chanWidth + f0) / 1e6  # MHz

        if (args.f0 is not None or args.f1 is not None):
            nchanInt = int(nchan)
            bchan = 0
            # echan is an exclusive slice end, so "up to the last channel" is nchan.
            echan = nchanInt
            # --f0/--f1 are Hz offsets when --freqref is given, else absolute MHz.
            if args.freqref is not None:
                fscale = 1
            else:
                fscale = 1e6
            if args.f0 is not None:
                bchan = int(np.rint((args.f0*fscale-f0)/chanWidth))
            if args.f1 is not None:
                echan = int(np.rint((args.f1*fscale-f0)/chanWidth)+1)
            # Clamp to the valid channel range. A negative index would wrap
            # around and silently select channels from the end of the band.
            bchan = min(max(bchan, 0), nchanInt)
            echan = min(max(echan, 0), nchanInt)
            doSlice = True
            xvals = xvals[bchan:echan]

        if args.average is not None:
            av = args.average
            xvals = np.mean(xvals[:(len(xvals)//av)*av].reshape(-1, av), axis=1)

        if args.freqref is not None or args.frest is not None:
            xvals *= 1e6  # Back to Hz

        if args.frest is not None:
            xvals *= C / (args.frest*1e6)

    this_specAv = np.mean(data, axis=0).astype('float64')

    if doSlice:
        this_specAv = this_specAv[:, bchan:echan]

    if first or args.overplot:
        first = False
        specAv = this_specAv
        totalInt = nInt
    else:
        # Weight each file's mean by its number of integrations.
        specAv = (specAv*totalInt + this_specAv*nInt)/(totalInt+nInt)
        totalInt += nInt

    if args.overplot:
        doplot(specAv, nIF)

if first:
    # No file contained any integrations, so there is nothing to plot.
    sys.exit("No integrations found in any input file")

if args.frest is not None:
    plt.xlabel('Velocity m/s')

if not args.overplot:
    doplot(specAv, nIF)
    ax = plt.gca()
    plt.text(0.03, 0.95, time, horizontalalignment='left', verticalalignment='center',transform=ax.transAxes)

if (args.outfile is not None):
    plt.savefig(args.outfile, dpi=args.dpi)

if show:
    plt.show()

