import struct
import numpy as np

class VDIFHeader:
    """Decoded VDIF data-frame header, with helpers for reading the payload.

    The constructor decodes words 0-3 (the first 16 bytes) of a 32-byte VDIF
    header.  Attributes set on success:

    invalid, legacy
        Flag bits 31 and 30 of word 0.
    seconds
        Seconds from the reference epoch (word 0, bits 29-0).
    epoch, frame
        Reference epoch (word 1, bits 29-24) and frame number
        (word 1, bits 23-0).
    version, nchan
        VDIF version (word 2, bits 31-29) and channel count, stored as
        ``1 << log2nchan`` where log2nchan is word 2, bits 28-24.
    framelength, datalength
        Whole frame length in bytes (header included) and payload length in
        bytes (``framelength - HeaderSize``).
    complex, nbits
        Complex-data flag and bits per sample component (a complex sample
        uses ``2 * nbits`` bits per channel, see ``completeSample``).
    threadid, stationid
        Thread ID (10 bits) and the two station-ID bytes decoded as a string.
    """

    HeaderSize = 32

    def __init__(self, bytes):
        """Decode a VDIF header.

        Parameters
        ----------
        bytes : bytes-like
            Raw header of at least ``HeaderSize`` bytes; only the first 16
            are decoded.  (The name shadows the builtin but is kept so
            keyword callers keep working.)

        Raises
        ------
        ValueError
            If fewer than ``HeaderSize`` bytes are supplied.
        """
        if len(bytes) < self.HeaderSize:
            # Returning from __init__ cannot signal failure and would leave an
            # object with none of the attributes below, so raise instead.
            raise ValueError("VDIF header needs {} bytes, got {}".format(
                self.HeaderSize, len(bytes)))

        # '<IIIssH' reads words 0-2 whole, then splits little-endian word 3
        # into its two low bytes (station ID, low byte first) and its upper
        # 16 bits (complex flag, nbits-1, thread ID).
        (seconds, frame, framelength, antid1, antid2, threadid) = struct.unpack_from(
            '<IIIssH', bytes)

        # Word 0: invalid flag, legacy flag, 30-bit seconds from epoch.
        self.invalid = (seconds >> 31) & 0x1
        self.legacy = (seconds >> 30) & 0x1
        self.seconds = seconds & 0x3FFFFFFF

        # Word 1: bits 31-30 unassigned, 6-bit reference epoch, 24-bit frame #.
        self.epoch = (frame >> 24) & 0x3F
        self.frame = frame & 0xFFFFFF

        # Word 2: version, log2(nchan), frame length in 8-byte units.
        self.version = (framelength >> 29) & 0x7
        self.nchan = 1 << ((framelength >> 24) & 0x1F)
        self.framelength = (framelength & 0xFFFFFF) * 8
        # NOTE(review): legacy-mode VDIF frames carry a 16-byte header; this
        # always assumes the 32-byte form -- confirm legacy data is not used.
        self.datalength = self.framelength - self.HeaderSize

        # Word 3 (upper 16 bits): complex flag, bits/sample - 1, thread ID.
        self.complex = (threadid >> 15) & 0x1
        self.nbits = ((threadid >> 10) & 0x1F) + 1
        self.threadid = threadid & 0x3FF
        # The high byte comes first in the printable ID.  latin-1 maps every
        # byte to one character, so numeric (non-ASCII) station IDs decode
        # instead of raising; ASCII IDs decode exactly as with UTF-8.
        self.stationid = (antid2 + antid1).decode('latin-1')

    def completeSample(self):
        """Return the number of bits in one complete sample (all channels)."""
        cs = self.nchan * self.nbits
        if self.complex:
            cs *= 2  # real and imaginary components
        return cs

    def completeSampleBytes(self):
        """Return the number of bytes in one complete sample.

        Raises
        ------
        ValueError
            If a complete sample is not a whole number of bytes.
        """
        cs = self.completeSample()
        if cs % 8:
            raise ValueError("Complete Sample not an exact number of bytes")
        return cs // 8

    def samplesPerFrame(self):
        """Return the number of complete samples in one frame's payload."""
        return (self.datalength * 8) // self.completeSample()

    def readFrame(self, f, docomplex=False):
        """Read one frame's payload from ``f`` and split it per channel.

        Parameters
        ----------
        f : file object or path
            Passed to ``numpy.fromfile``; presumably positioned at the start
            of the payload, just after the header -- confirm with the caller.
        docomplex : bool
            For complex data, return one complex array instead of a
            ``(real, imag)`` tuple.

        Returns
        -------
        numpy.ndarray or tuple of numpy.ndarray
            Float arrays shaped ``(nchan, nsamples)``.  If fewer values than
            a full frame are available, the shorter data is returned as read.

        Raises
        ------
        ValueError
            If ``nbits`` is not 8 or 16.
        """
        dtypes = {8: np.uint8, 16: np.uint16}
        if self.nbits not in dtypes:
            raise ValueError("Do not support nbits={}".format(self.nbits))

        count = self.datalength // (self.nbits // 8)
        data = np.fromfile(f, dtype=dtypes[self.nbits], count=count).astype(float)
        # Re-centre the unsigned samples on zero by subtracting mid-scale
        # (128 for 8-bit, 32768 for 16-bit).
        data -= 1 << (self.nbits - 1)

        if not self.complex:
            return data.reshape(-1, self.nchan).transpose(1, 0)

        # Complex samples are interleaved real, imaginary for each channel.
        data_real = data[0::2].reshape(-1, self.nchan).transpose(1, 0)
        data_imag = data[1::2].reshape(-1, self.nchan).transpose(1, 0)
        if docomplex:
            data_complex = np.array(data_real, dtype=complex)
            data_complex.imag = data_imag
            return data_complex
        return (data_real, data_imag)

    def next(self, incr=1):
        """Advance the header's frame counter by ``incr``, carrying into seconds.

        NOTE(review): ``self.period`` and ``self.framesPerPeriod()`` are not
        defined in this class; they must be supplied elsewhere (a subclass or
        attributes set by the caller), otherwise this raises AttributeError.
        """
        self.frame += incr
        frameperperiod = self.framesPerPeriod()

        while self.frame >= frameperperiod:
            self.seconds += self.period
            self.frame -= frameperperiod

    def __str__(self):
        """Return a multi-line, human-readable dump of the decoded fields."""
        return '''FRAME#:    {}
SECONDS: {}

EPOCH:       {}
NBITS:       {}
INVALID:     {}
COMPLEX:     {}
VERSION:     {}

THREADID:    {}
STATIONID:   {}

NCHAN:       {}
FRAMELENGTH: {}
DATALENGTH:  {}'''.format(self.frame, self.seconds, self.epoch, self.nbits, self.invalid, self.complex,
                              self.version, self.threadid, self.stationid, self.nchan,
                              self.framelength, self.datalength)

#def deltaFrame(header1, header2):
#    return round((header2.seconds-header1.seconds)/header1.period*header1.framesPerPeriod()+(header2.frame-header1.frame))
