#!/usr/bin/env python3

# Read VDIF file and print sample statistics

import numpy as np
import argparse, sys

# Maximum number of frames to read when -n/--nframe is not given.
NFRAME_DEFAULT = 1000

# Supported sample sizes: bits per sample -> (little-endian on-disk dtype,
# offset subtracted to turn the offset-binary codes into signed values).
_SAMPLE_FORMATS = {
    8: ('<u1', 128),
    16: ('<u2', 32768),
}


def parse_header(words):
    """Decode the VDIF header fields this script needs.

    Args:
        words: the first four 32-bit words of a VDIF frame header (any
            sequence of integers, e.g. a numpy uint32 array).

    Returns:
        dict with keys:
          ``framesize``  -- frame length in bytes, header included
                            (word 2 bits 0-23 give it in units of 8 bytes);
          ``headersize`` -- 16 when the legacy-mode bit (word 0 bit 30) is
                            set, otherwise 32;
          ``nchan``      -- number of channels (word 2 bits 24-28 hold log2);
          ``bits``       -- bits per sample (word 3 bits 26-30 hold bits-1);
          ``complex``    -- True when word 3 bit 31 (complex flag) is set.
    """
    w0, _, w2, w3 = (int(w) for w in words[:4])
    return {
        'framesize': (w2 & 0xFFFFFF) * 8,
        'headersize': 16 if (w0 >> 30) & 0x1 else 32,
        'nchan': 2 ** ((w2 >> 24) & 0x1F),
        'bits': ((w3 >> 26) & 0x1F) + 1,
        'complex': (w3 >> 31) & 0x1 == 1,
    }


def vdif_stats(filename, nframe=NFRAME_DEFAULT):
    """Accumulate sample statistics over the first frames of a VDIF file.

    The header of the first frame determines the frame size, sample size and
    real/complex layout; every following frame is assumed to share them.
    Samples from all channels are pooled together. Reading stops after
    ``nframe`` frames or at the first incomplete frame, whichever comes first.

    Args:
        filename: path of the VDIF file.
        nframe: maximum number of frames to read.

    Returns:
        dict with keys ``complex`` (bool), ``nframe`` (frames actually read),
        ``mean`` (float, or complex for complex data), ``std`` (tuple with one
        standard deviation per component: ``(real,)`` or ``(real, imag)``),
        and ``max``/``min`` (floats; for complex data, the extremes over both
        the real and the imaginary parts).

    Raises:
        OSError: if the file cannot be opened or read.
        ValueError: if the header is truncated, describes an unsupported
            sample size or an invalid frame size, or if no complete frame
            could be read.
    """
    with open(filename, 'rb') as fd:
        # The first 16 bytes hold every field we need and are present in
        # both standard (32-byte) and legacy (16-byte) headers.
        raw = fd.read(16)
        if len(raw) < 16:
            raise ValueError('{}: too short to hold a VDIF header'.format(filename))
        hdr = parse_header(np.frombuffer(raw, dtype='<u4'))

        bits = hdr['bits']
        if bits not in _SAMPLE_FORMATS:
            raise ValueError('Unsupported # bits ({})'.format(bits))
        dtype, offset = _SAMPLE_FORMATS[bits]

        ncomp = 2 if hdr['complex'] else 1  # complex samples are (re, im) pairs
        headersize = hdr['headersize']
        datablock = hdr['framesize'] - headersize
        if datablock <= 0 or datablock % (np.dtype(dtype).itemsize * ncomp):
            raise ValueError('Invalid frame size ({} bytes)'.format(hdr['framesize']))

        # Skip the remainder of the first header (nothing for legacy headers).
        fd.seek(headersize - len(raw), 1)

        # Accumulate in float64: summing squares of 16-bit values in float32
        # loses precision, and the variance below subtracts two large terms.
        nsamp = 0
        sums = np.zeros(ncomp)
        sumsqr = np.zeros(ncomp)
        vmax = -np.inf  # +/-inf so that any real sample replaces them
        vmin = np.inf
        frames = 0
        while frames < nframe:
            payload = fd.read(datablock)
            if len(payload) < datablock:
                break  # end of file, or a truncated final frame
            samples = (np.frombuffer(payload, dtype=dtype).astype(np.float64)
                       - offset).reshape(-1, ncomp)
            sums += samples.sum(axis=0)
            sumsqr += np.square(samples).sum(axis=0)
            nsamp += samples.shape[0]
            vmax = max(vmax, samples.max())
            vmin = min(vmin, samples.min())
            frames += 1
            fd.seek(headersize, 1)  # skip the next frame's header

    if nsamp == 0:
        raise ValueError('{}: no complete data frames'.format(filename))

    mean = sums / nsamp
    # Var = E[x^2] - E[x]^2; clamp tiny negative rounding residue to 0.
    std = np.sqrt(np.maximum(sumsqr / nsamp - mean * mean, 0.0))
    return {
        'complex': hdr['complex'],
        'nframe': frames,
        'mean': complex(mean[0], mean[1]) if hdr['complex'] else float(mean[0]),
        'std': tuple(float(s) for s in std),
        'max': float(vmax),
        'min': float(vmin),
    }


def main(argv=None):
    """Command-line entry point: print sample statistics for a VDIF file."""
    parser = argparse.ArgumentParser(description='Read VDIF file and print sample statistics')
    parser.add_argument('-n', '--nframe', type=int, default=NFRAME_DEFAULT,
                        help='Maximum number of frames to read (default %(default)s)')
    parser.add_argument('vdiffile', help="VDIF file")
    args = parser.parse_args(argv)

    try:
        stats = vdif_stats(args.vdiffile, args.nframe)
    except (OSError, ValueError) as err:
        # sys.exit with a string prints it to stderr and exits with status 1.
        sys.exit('{}. Quitting'.format(err))

    print("Mean = {:.3f}".format(stats['mean']))
    if stats['complex']:
        print("StdDev = {:.2f}/{:.2f}".format(*stats['std']))
    else:
        print("StdDev = {:.1f}".format(stats['std'][0]))
    print("Max = {}".format(stats['max']))
    print("Min = {}".format(stats['min']))


if __name__ == "__main__":
    main()

