#!/usr/bin/env python3

import sys, struct, argparse
import numpy as np


class CodifHeader:
    """Decoded CODIF frame header.

    Only the first 44 bytes of the supplied buffer are decoded
    (little-endian). The caller decides how many bytes a header occupies on
    disk and is responsible for advancing the file past it.
    """

    def __init__(self, bytes):
        """Unpack the header fields from *bytes*.

        Args:
            bytes: buffer holding at least 44 header bytes. (The parameter
                name shadows the builtin but is kept for keyword callers.)

        Raises:
            struct.error: if fewer than 44 bytes are supplied.
        """
        (frame, seconds, epoch, nbits, small, period, threadid, groupid,
         secondaryid, antid1, antid2, nchan, sampleblocklength, framelength,
         numsamples, sync) = struct.unpack('<IIBBHxxHHHHssHHIQI', bytes[:44])

        # Single-bit flags and small bit-fields packed into one 16-bit word
        self.power = small & 0x1
        self.invalid = (small >> 1) & 0x1
        self.complex = (small >> 2) & 0x1
        self.calon = (small >> 3) & 0x1
        self.representation = (small >> 4) & 0xF
        self.version = (small >> 8) & 0x1F
        self.protocol = (small >> 13) & 0x7

        self.frame = frame
        self.seconds = seconds
        self.epoch = epoch
        self.nbits = nbits
        self.period = period
        self.threadid = threadid
        self.groupid = groupid
        self.secondaryid = secondaryid
        # The two ID characters are joined in the reverse of their unpack
        # order. errors="replace" keeps a corrupt header from raising here,
        # before the caller has had a chance to check the sync word.
        self.stationid = (antid2 + antid1).decode(errors="replace")
        self.nchan = nchan
        self.sampleblocklength = sampleblocklength
        # The header field counts 8-byte units; store the data length in bytes
        self.framelength = framelength * 8
        self.numsamples = numsamples
        self.sync = sync

    def completeSample(self):
        """Return the number of bits in one complete sample.

        A complete sample covers every channel, and both the real and the
        imaginary part when the data are complex.
        """
        bits = self.nchan * self.nbits
        if self.complex:
            bits *= 2
        return bits

    def completeSampleBytes(self):
        """Return the size of one complete sample in bytes.

        Raises:
            ValueError: if a complete sample is not a whole number of bytes.
        """
        bits = self.completeSample()
        if bits % 8:
            raise ValueError("Complete Sample not an exact number of bytes")
        return bits // 8

    def readFrame(self, f, complex=False):
        """Read one frame of sample data from *f* (positioned after the header).

        Args:
            f: open binary file; ``self.framelength`` bytes are consumed.
            complex: when the data are complex, return one complex array
                instead of a ``(real, imag)`` tuple. (The name shadows the
                builtin but is kept for keyword callers.)

        Returns:
            Real data: float array of shape (nchan, nsamples).
            Complex data: ``(real, imag)`` float arrays of that shape, or a
            complex128 array of that shape when *complex* is True.

        Raises:
            ValueError: if ``nbits`` is not 8 or 16.
        """
        if self.nbits == 16:
            sample_type = np.short
        elif self.nbits == 8:
            sample_type = np.byte
        else:
            raise ValueError("Do not support nbits={}".format(self.nbits))

        nvalues = self.framelength // (self.nbits // 8)
        data = np.fromfile(f, dtype=sample_type, count=nvalues).astype(float)

        if not self.complex:
            return data.reshape(-1, self.nchan).transpose(1, 0)

        # Real and imaginary values are interleaved; each part then cycles
        # through the channels once per time sample.
        data_real = data[0::2].reshape(-1, self.nchan).transpose(1, 0)
        data_imag = data[1::2].reshape(-1, self.nchan).transpose(1, 0)
        if complex:
            # `complex` is the bool parameter here, not the builtin type, so
            # build the complex array arithmetically rather than via dtype=.
            return data_real + 1j * data_imag
        return data_real, data_imag

    def print(self):
        """Print every decoded header field as a human-readable table."""
        print(
'''FRAME#:    {}
SECONDS: {}

EPOCH:       {}
NBITS:       {}
POWER:       {}
INVALID:     {}
COMPLEX:     {}
CALON:       {}
REPR:        {}
VERSION:     {}
PROTOCOL:    {}
PERIOD:      {}

THREADID:    {}
GROUPID:     {}
SECONDARYID: {}
STATIONID:   {}

NCHAN:       {}
SAMPLEBLOCK: {}
DATALENGTH:  {}

#SAMPLES:    {}

SYNC:        0x{:X}'''.format(self.frame, self.seconds, self.epoch, self.nbits, self.power, self.invalid, self.complex,
                              self.calon, self.representation, self.version, self.protocol, self.period, self.threadid,
                              self.groupid, self.secondaryid, self.stationid, self.nchan, self.sampleblocklength,
                              self.framelength, self.numsamples, self.sync))


# Command line: one or more CODIF files, plus an optional number of frames
# to read from each.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-n', '-nframe', '--nframe',
    type=int,
    default=1,
    help="Number of frames to read",
)
parser.add_argument('codiffile', nargs='+')
args = parser.parse_args()

def frameHisto(frame, nchan):
    """Histogram the 4-bit values of one frame of 4-bit complex data.

    Each byte is one complex value: the low nibble is counted as the real
    part and the high nibble as the imaginary part. Byte ``i`` belongs to
    channel ``i % nchan``.

    Args:
        frame: raw frame bytes (bytes-like), or a sequence of byte values.
        nchan: number of channels interleaved in the frame.

    Returns:
        int array of shape (2*nchan, 16). Row ``2*j`` holds the real-part
        counts of channel ``j`` and row ``2*j+1`` the imaginary-part counts,
        indexed by the unsigned nibble value 0..15.

    Raises:
        ValueError: if ``nchan`` is not positive, or the frame length is not
            a multiple of ``nchan``.
    """
    if nchan <= 0:
        raise ValueError("Number of channels must be positive")

    if isinstance(frame, (bytes, bytearray, memoryview)):
        samples = np.frombuffer(frame, dtype=np.uint8)
    else:
        samples = np.asarray(frame).astype(np.uint8)

    if samples.size % nchan:
        raise ValueError("Frame size must be multiple of # channels")
    samples = samples.reshape(-1, nchan)

    histo = np.zeros((nchan * 2, 16), dtype=int)
    for chan in range(nchan):
        column = samples[:, chan]
        histo[2 * chan] = np.bincount(column & 0xF, minlength=16)   # real: low nibble
        histo[2 * chan + 1] = np.bincount(column >> 4, minlength=16)  # imag: high nibble
    return histo

np.set_printoptions(linewidth=200)

# Bytes consumed per on-disk frame header (CodifHeader decodes the first 44).
HEADER_BYTES = 64
# Value the sync field of every valid frame header must hold.
SYNC_WORD = 0xFEEDCAFE


def _print_stat(label, values):
    """Print *values* on one line as '<label> = [ v0, v1, ... ]' (2 decimals)."""
    print(label, "= [", ", ".join(f"{x:.2f}" for x in values), "]")


# For each file, accumulate per-channel sums and sums of squares over up to
# --nframe frames. Then report the per-channel mean and standard deviation.
for filename in args.codiffile:

    with open(filename, "rb") as f:
        # Peek at the first header to learn the frame layout, then rewind so
        # the loop below re-reads it as the first frame.
        first = f.read(HEADER_BYTES)
        if len(first) < HEADER_BYTES:
            print("Error: {} does not start with a complete header".format(filename),
                  file=sys.stderr)
            continue
        header = CodifHeader(first)
        f.seek(0)

        framesize = header.framelength
        iscomplex = header.complex
        nbit = header.nbits
        nchan = header.nchan

        completeSample = header.completeSampleBytes()
        samplesperFrame = framesize // completeSample

        print("*********************************")
        print(filename)
        print("ComplexSample= ", completeSample, "   samplesPerFrame=", samplesperFrame)

        # Running per-channel sums and sums of squares over all frames read
        sumMean_real = np.zeros(nchan)
        sumMean_imag = np.zeros(nchan)
        sumSqr_real = np.zeros(nchan)
        sumSqr_imag = np.zeros(nchan)
        nsamples = 0  # time samples accumulated per channel

        for _ in range(args.nframe):
            raw = f.read(HEADER_BYTES)
            if len(raw) < HEADER_BYTES:
                break  # end of file: fewer frames than requested
            header = CodifHeader(raw)
            if header.sync != SYNC_WORD:
                print("Error: lost sync", file=sys.stderr)
                sys.exit(1)

            if iscomplex:
                data_real, data_imag = header.readFrame(f)
            else:
                data_real = header.readFrame(f)

            # Accumulate sums (not per-frame means). Dividing by the total
            # sample count afterwards then gives the mean over all frames.
            sumMean_real += data_real.sum(axis=1)
            sumSqr_real += np.sum(data_real**2, axis=1)
            if iscomplex:
                sumMean_imag += data_imag.sum(axis=1)
                sumSqr_imag += np.sum(data_imag**2, axis=1)
            nsamples += data_real.shape[1]

        if nsamples == 0:
            print("No frames read")
            continue

        sumMean_real /= nsamples
        sumMean_imag /= nsamples
        # Var = E[x^2] - E[x]^2. Clamp tiny negative rounding residue before sqrt.
        std_real = np.sqrt(np.maximum(sumSqr_real / nsamples - sumMean_real**2, 0.0))
        std_imag = np.sqrt(np.maximum(sumSqr_imag / nsamples - sumMean_imag**2, 0.0))

        _print_stat("MeanReal", sumMean_real)
        if iscomplex:
            _print_stat("MeanImag", sumMean_imag)
        _print_stat("StdReal", std_real)
        if iscomplex:
            _print_stat("StdImag", std_imag)
            
            

#        while nframe<maxFrame:
#            data = f.read(framelength)
#            if not data: break
#            nframe += 1
#            histo += frameHisto(data, nchan)

#            h = f.read(64)
#            if not h: break

#    shuffleIdx = list(range(8,16))+list(range(8))
#    for n in range(-8,8): print("{:5d} ".format(n),end="")
#    print()
    
#    for r in histo[:,shuffleIdx]:
#        sum = r.sum()
#        norm = r.astype(float)/sum*100
#        for n in norm: print(" {:5.2f}".format(n),end="")
#        print()
