#!/usr/bin/env python3

import sys, struct, argparse
import numpy as np
import matplotlib.pyplot as plt

class CodifHeader:
    """Parsed CODIF frame header.

    Only the first 44 bytes of the supplied buffer are decoded (the script
    reads 64-byte headers). ``framelength`` is stored in bytes: the raw
    header field is multiplied by 8 on parsing.
    """

    def __init__(self, bytes):
        """Unpack a raw header buffer into attributes.

        Args:
            bytes: Raw header bytes; the first 44 are decoded. (The name
                shadows the builtin but is kept for keyword callers.)

        Raises:
            struct.error: if fewer than 44 bytes are supplied.
        """
        (frame, seconds, epoch, nbits, small, period, threadid, groupid,
         secondaryid, antid1, antid2, nchan, sampleblocklength,
         framelength, numsamples, sync) = struct.unpack(
            '<IIBBHxxHHHHssHHIQI', bytes[:44])

        # Single-bit and small multi-bit fields packed into one 16-bit word.
        self.power = small & 0x1
        self.invalid = (small >> 1) & 0x1
        self.complex = (small >> 2) & 0x1
        self.calon = (small >> 3) & 0x1
        self.representation = (small >> 4) & 0xF
        self.version = (small >> 8) & 0x1F
        self.protocol = (small >> 13) & 0x7

        self.frame = frame
        self.seconds = seconds
        self.epoch = epoch
        self.nbits = nbits
        self.period = period
        self.threadid = threadid
        self.groupid = groupid
        self.secondaryid = secondaryid
        # The two station-ID characters are joined in reverse order
        # (antid2 then antid1). errors="replace" keeps a non-text station
        # ID from aborting header parsing.
        self.stationid = (antid2 + antid1).decode(errors="replace")
        self.nchan = nchan
        self.sampleblocklength = sampleblocklength
        self.framelength = framelength * 8  # raw field counts 8-byte units
        self.numsamples = numsamples
        self.sync = sync

    def completeSample(self):
        """Return the number of bits in one complete (all-channel) sample."""
        cs = self.nchan * self.nbits
        if self.complex:
            cs *= 2  # real and imaginary parts
        return cs

    def completeSampleBytes(self):
        """Return the number of bytes in one complete sample.

        Raises:
            ValueError: if a complete sample is not a whole number of bytes.
        """
        cs = self.completeSample()
        if cs % 8:
            raise ValueError("Complete Sample not an exact number of bytes")
        return cs // 8

    def readFrame(self, f, docomplex=False):
        """Read one frame's payload from ``f`` and split it by channel.

        ``f`` must be a binary file-like object positioned at the start of
        the payload (just after the header) and must provide ``read``.

        Args:
            f: Binary file-like object to read ``self.framelength`` bytes from.
            docomplex: For complex data, return one complex array instead of
                a ``(real, imag)`` tuple.

        Returns:
            For real data, a float array of shape ``(nchan, nsamples)``. For
            complex data, either a complex array of that shape
            (``docomplex=True``) or a ``(real, imag)`` tuple of float arrays.

        Raises:
            ValueError: if ``nbits`` is neither 8 nor 16.
            EOFError: if fewer than ``framelength`` bytes could be read.
        """
        if self.nbits == 16:
            # Explicit little-endian, matching the '<' header layout; on
            # little-endian hosts this equals the previous native-order read.
            thisType = np.dtype('<i2')
        elif self.nbits == 8:
            thisType = np.dtype(np.int8)
        else:
            raise ValueError("Do not support nbits={}".format(self.nbits))

        nbytes = self.framelength
        raw = f.read(nbytes)
        if len(raw) < nbytes:
            raise EOFError(
                "Truncated CODIF frame: expected {} bytes, got {}".format(
                    nbytes, len(raw)))
        data = np.frombuffer(raw, dtype=thisType).astype(float)

        if not self.complex:
            return data.reshape(-1, self.nchan).transpose(1, 0)

        # Samples are interleaved real/imag; after de-interleaving, reshape
        # to (sample, channel) and transpose to (channel, sample).
        data_real = data[0::2].reshape(-1, self.nchan).transpose(1, 0)
        data_imag = data[1::2].reshape(-1, self.nchan).transpose(1, 0)

        if docomplex:
            data_complex = np.array(data_real, dtype=complex)
            data_complex.imag = data_imag
            return data_complex
        return (data_real, data_imag)

    def print(self):
        """Print every decoded header field, one per line, to stdout."""
        print(
'''FRAME#:    {}
SECONDS: {}

EPOCH:       {}
NBITS:       {}
POWER:       {}
INVALID:     {}
COMPLEX:     {}
CALON:       {}
REPR:        {}
VERSION:     {}
PROTOCOL:    {}
PERIOD:      {}

THREADID:    {}
GROUPID:     {}
SECONDARYID: {}
STATIONID:   {}

NCHAN:       {}
SAMPLEBLOCK: {}
DATALENGTH:  {}

#SAMPLES:    {}

SYNC:        0x{:X}'''.format(self.frame, self.seconds, self.epoch, self.nbits, self.power, self.invalid, self.complex,
                              self.calon, self.representation, self.version, self.protocol, self.period, self.threadid,
                              self.groupid, self.secondaryid, self.stationid, self.nchan, self.sampleblocklength,
                              self.framelength, self.numsamples, self.sync))


# Command-line interface: one or more CODIF files plus plotting options.
parser = argparse.ArgumentParser()
parser.add_argument('-n', '-nframe', '--nframe', help="Number of frames to read", default=2, type=int)
# --diff / --multiply compare every file's channel against the first file's.
parser.add_argument('-d', '-diff', '--diff', help="Difference files", action='store_true')
parser.add_argument('-p', '-phase', '--phase', help="Plot Phase", action='store_true')
parser.add_argument('-r', '-real', '--real', help="Plot real data (default amplitude)", action='store_true')
parser.add_argument('-b', '-both', '--both', help="Plot both real and imag data", action='store_true')
parser.add_argument('-i', '-imag', '--imag', help="Plot imaginary data (default amplitude)", action='store_true')
parser.add_argument('-m', '-multiply', '--multiply', help="Multiply channels", action='store_true')
parser.add_argument('-f', '-fake', '--fake', help="Generate Fake data also", action='store_true')
parser.add_argument('-c', '-channel', '--channel', help="Channel to plot", type=int, default=0)
parser.add_argument('codiffile', nargs='+', help="CODIF file(s) to read")
args = parser.parse_args()

np.set_printoptions(linewidth=100)

channel = args.channel

# Per-file channel data, filled in only for the --diff / --multiply modes.
allData = []

# Transform applied to the data before plotting. The first enabled entry
# wins; amplitude is the fallback when no selector flag is given.
_plotTransforms = (
    (args.real or args.both, np.real),
    (args.imag, np.imag),
    (args.phase, np.angle),
)
myfunc = next((fn for wanted, fn in _plotTransforms if wanted), np.abs)

# Read just the first header of every file.
firstFrames = []
for fname in args.codiffile:
    with open(fname, "rb") as fh:
        firstFrames.append(CodifHeader(fh.read(64)))

# Frame layout is taken from the first file.
reference = firstFrames[0]
framesize = reference.framelength
iscomplex = reference.complex
nbit = reference.nbits
nchan = reference.nchan

completeSample = reference.completeSampleBytes()
samplesperFrame = framesize // completeSample

# All files must share the first file's frame size (trivially true for one).
if any(hdr.framelength != framesize for hdr in firstFrames):
    print("Error, framesizes do not match")
    sys.exit(1)

# Start reading every file at the latest (seconds, frame) any file begins
# with, so that files starting "late" still line up.
latest = max(firstFrames, key=lambda hdr: (hdr.seconds, hdr.frame))
startSec, startFrame = latest.seconds, latest.frame

print(f"Start Sec = {startSec}, Start frame = {startFrame}")
    
# For each file, read args.nframe frames at or after the common start time,
# then plot the selected channel (or keep it for --diff / --multiply).
for file in args.codiffile:
    frames = []  # per-frame (nchan, nsamples) arrays; concatenated once below
    with open(file, "rb") as f:
        while len(frames) < args.nframe:
            rawheader = f.read(64)
            if len(rawheader) < 64:
                print(f"Error: {file} ended after {len(frames)} of {args.nframe} frames")
                sys.exit(1)
            header = CodifHeader(rawheader)
            if header.sync != 0xFEEDCAFE:
                print("Error: lost sync")
                sys.exit(1)
            # The payload is read even for skipped frames to advance f.
            thisdata = header.readFrame(f, docomplex=iscomplex)

            if header.seconds < startSec or (header.seconds == startSec and header.frame < startFrame):
                print(f"Skipping {header.seconds}/{header.frame}")
            else:
                if not frames:
                    print(thisdata.shape)
                frames.append(thisdata)

    if not frames:
        print(f"Error: no frames read from {file}")
        sys.exit(1)
    # One concatenate instead of re-copying a growing array for every frame.
    data = np.concatenate(frames, axis=1)

    if not (args.diff or args.multiply):
        plt.plot(myfunc(data[channel, :]))
        if args.both:
            plt.plot(np.imag(data[channel, :]))
    else:
        allData.append(data[channel, :])

if args.fake:
    print("Generate Fake")
    # The synthetic tone is as long as the selected channel of the last file read.
    num_samples = data[channel, :].shape[0]

    frequency = 0.05*1e6  # tone frequency (cycles per second)
    sampling_rate = 1.6e6*32.0/27.0  # samples per second
    duration = num_samples/sampling_rate  # signal length in seconds

    # Sample instants, excluding the end point.
    t = np.linspace(0, duration, num_samples, endpoint=False)

    # Real part sin, imaginary part cos, scaled by a fixed amplitude.
    phase = 2*np.pi * frequency * t
    sine_fake = (np.sin(phase) + 1j * np.cos(phase))*11680

    plain_plot = not (args.diff or args.multiply)
    if plain_plot:
        plt.plot(myfunc(sine_fake))


        
if args.diff:
    # Every later file, then the fake tone if requested, is differenced
    # against the first file's channel data.
    refData = allData[0]
    diffSeries = allData[1:] + ([sine_fake] if args.fake else [])
    for d in diffSeries:
        if args.phase:
            # Phase difference wrapped into [-pi, pi).
            plt.plot((np.pi + np.angle(d)-np.angle(refData)) % (2*np.pi) - np.pi)
            continue
        plt.plot(myfunc(d)-myfunc(refData))
        if args.both:
            plt.plot(np.imag(d)-np.imag(refData))


if args.multiply:
    # Multiply every later file, then the fake tone if requested, by the
    # conjugate of the first file's channel data. The fake tone is treated
    # exactly like a real file, matching the --diff handling.
    refConj = allData[0].conj()
    multSeries = allData[1:] + ([sine_fake] if args.fake else [])
    for d in multSeries:
        product = d*refConj
        plt.plot(myfunc(product))
        if args.both:
            plt.plot(np.imag(product))

plt.show()
             
