#!/usr/bin/env python3

import sys, argparse
import numpy as np
import matplotlib.pyplot as plt

from vdif import VDIFHeader

# Command-line interface: display-mode flags, comparison flags, and one or
# more VDIF files to read.
parser = argparse.ArgumentParser()
parser.add_argument('-n', '-nframe', '--nframe', help="Number of frames to read", default=2, type=int)
# Comparison modes: later files are compared against the first file given
parser.add_argument('-d', '-diff', '--diff', help="Difference files", action='store_true')
# Display modes: amplitude is plotted unless one of these is set
parser.add_argument('-p', '-phase', '--phase', help="Plot Phase", action='store_true')
parser.add_argument('-r', '-real', '--real', help="Plot real data (default amplitude)", action='store_true')
parser.add_argument('-b', '-both', '--both', help="Plot both real and imag data", action='store_true')
# The help text used to be a copy of --real's and described the wrong component
parser.add_argument('-i', '-imag', '--imag', help="Plot imaginary data (default amplitude)", action='store_true')
parser.add_argument('-m', '-multiply', '--multiply', help="Multiply channels", action='store_true')
parser.add_argument('-q', '-quotient', '--quotient', help="Divide channels", action='store_true')
# Synthetic tone: --fake is the tone frequency, --bandwidth the sampling rate
parser.add_argument('-f', '-fake', '--fake', help="Generate Fake data also, in MHz", type=float)
parser.add_argument('-bandwidth', '--bandwidth', help="Bandwidth of sampling, in MHz", type=float, default=128)
parser.add_argument('-c', '-channel', '--channel', help="Channel to plot", type=int, default=0)
parser.add_argument('-t', '-time', '--time', help="Plot 'x' axis in uSec", action='store_true')
parser.add_argument('vdiffile', nargs='+')
args = parser.parse_args()

np.set_printoptions(linewidth=100)

channel = args.channel

# Per-file samples of the selected channel, collected only when files are
# being compared (diff / multiply / quotient) rather than plotted directly.
allData = []

# Choose the transform applied to the samples before plotting. The first
# enabled entry wins, so real/both beats imag, which beats phase; amplitude
# is used when no display flag is given.
_transform_by_flag = (
    (args.real or args.both, np.real),
    (args.imag, np.imag),
    (args.phase, np.angle),
)
myfunc = next((func for enabled, func in _transform_by_flag if enabled), np.abs)

# Parse a header from the first 32 bytes of every input file. The header of
# the first file supplies the frame parameters used for the rest of the run.
firstFrames = []
for path in args.vdiffile:
    with open(path, "rb") as fh:
        firstFrames.append(VDIFHeader(fh.read(32)))

firstHeader = firstFrames[0]
framesize, iscomplex = firstHeader.framelength, firstHeader.complex
nbit, nchan = firstHeader.nbits, firstHeader.nchan

# Find the latest first-frame time (seconds, then frame number) over all
# input files, and insist that every file uses the same frame length.

startSec, startFrame = firstFrames[0].seconds, firstFrames[0].frame

# The first file seeds the values above, so only the others need checking
for other in firstFrames[1:]:
    if other.framelength != framesize:
        print("Error, framesizes do not match")
        sys.exit(1)
    # Tuple comparison orders by seconds first and frame number second
    if (other.seconds, other.frame) > (startSec, startFrame):
        startSec, startFrame = other.seconds, other.frame

print(f"Start Sec = {startSec}, Start frame = {startFrame}")
    
# For each file: read frames until args.nframe of them fall at or after the
# common start time, join them along the sample axis, then either plot the
# selected channel or stash it in allData for the comparison modes.
for file in args.vdiffile:
    frames = []  # payloads of the frames kept for this file, in read order
    with open(file, "rb") as f:
        while len(frames) < args.nframe:
            rawHeader = f.read(32)
            if len(rawHeader) < 32:
                # End of file before enough frames were found; keep what we
                # have instead of handing a truncated header to VDIFHeader.
                print(f"Warning: {file} ended after {len(frames)} usable frame(s)")
                break
            header = VDIFHeader(rawHeader)
            thisdata = header.readFrame(f,docomplex=iscomplex)

            # Frames earlier than the common start are read but discarded
            if header.seconds<startSec or (header.seconds==startSec and header.frame<startFrame):
                print(f"Skipping {header.seconds}/{header.frame}")
            else:
                frames.append(thisdata)

    if not frames:
        print(f"Error, no frames to plot from {file}")
        sys.exit(1)

    # Join once after reading: concatenating inside the loop re-copies every
    # previously read sample each time a frame is added.
    data = frames[0] if len(frames) == 1 else np.concatenate(frames, axis=1)

    x = np.arange(len(data[channel,:]))
    if args.time:
        # presumably sample_nSec() is the sample period in nanoseconds, so
        # dividing by 1000 gives the microseconds used for the axis label
        x = x * header.sample_nSec() / 1000.0

    if not (args.diff or args.multiply or args.quotient):
        plt.plot(x, myfunc(data[channel,:]))
        if args.both:
            # --both selects np.real as myfunc; add the imaginary part too
            plt.plot(x, np.imag(data[channel,:]))
    else:
        allData.append(data[channel,:])
       
if args.fake is not None:
    print("Generate Fake")

    # Both options are given in MHz; convert to Hz
    tone_hz = args.fake * 1e6
    rate_hz = args.bandwidth * 1e6

    # Match the length of the data read from the last file
    num_samples = len(data[channel, :])

    # Sample instants: num_samples points spaced 1/rate_hz apart from zero
    t = np.linspace(0, num_samples / rate_hz, num_samples, endpoint=False)

    # Complex tone cos(phase) + j*sin(phase), with a fixed amplitude of 120
    phase = 2*np.pi * tone_hz * t
    sine_fake = (np.cos(phase) + 1j * np.sin(phase)) * 120

    # In the comparison modes the tone is plotted by the mode's own section
    comparing = args.diff or args.multiply or args.quotient
    if not comparing:
        plt.plot(x, myfunc(sine_fake))

if args.diff:
    # Everything is compared against the first file: the remaining files in
    # command-line order, then the synthetic tone when one was generated.
    ref = allData[0]
    comparands = allData[1:]
    # "is not None" matches the test used when the tone is generated, so a
    # 0 MHz tone is compared as well (a plain truth test dropped it).
    if args.fake is not None:
        comparands = comparands + [sine_fake]

    for d in comparands:
        if args.phase:
            # Phase difference wrapped into [-pi, pi)
            plt.plot(x, (np.pi + np.angle(d)-np.angle(ref)) % (2*np.pi)  - np.pi )
        else:
            plt.plot(x, myfunc(d)-myfunc(ref))
            if args.both:
                plt.plot(x, np.imag(d)-np.imag(ref))

if args.multiply:
    # Multiply each later file, and the synthetic tone when one was
    # generated, by the complex conjugate of the first file.
    ref_conj = allData[0].conj()
    products = [d * ref_conj for d in allData[1:]]
    # "is not None" matches the test used when the tone is generated, so a
    # 0 MHz tone is included as well.
    if args.fake is not None:
        products.append(sine_fake * ref_conj)

    # The tone is now plotted exactly like a file. Previously it ignored the
    # selected display mode, and with --both it plotted
    # imag(tone) - imag(product) instead of the product's imaginary part.
    for prod in products:
        plt.plot(x, myfunc(prod))
        if args.both:
            plt.plot(x, np.imag(prod))


if args.quotient:
    print("Quotient")
    # Divide each later file by the first file, sample by sample
    for d in allData[1:]:
        ratio = d/allData[0]
        plt.plot(x, myfunc(ratio))
        if args.both:
            plt.plot(x, np.imag(ratio))
    # "is not None" matches the test used when the tone is generated, so a
    # 0 MHz tone is included as well.
    if args.fake is not None:
        ratio = sine_fake/allData[0]
        if args.phase:
            # Phase of the quotient wrapped into [-pi, pi)
            plt.plot(x, (np.pi + np.angle(ratio)) % (2*np.pi)  - np.pi )
        else:
            # Previously this plotted sine_fake - allData[0], a difference,
            # although the mode is a quotient.
            plt.plot(x, myfunc(ratio))
            if args.both:
                plt.plot(x, np.imag(ratio))


                
# Label the x axis to match the units chosen with --time, then display
plt.xlabel("Time (uSec)" if args.time else "Sample #")

plt.show()
             
