#!/usr/bin/env python3

# Read a zerocorr .vis file and compute the delay between the two datastreams
#
# The input visibility file (the positional `visfile` argument) has 8 columns:
#   1  Channel (spectral point) number
#   2  Frequency relative to first spectral channel (Hz)
#   3  Real value of the visibility
#   4  Imaginary value of the visibility
#   5  Amplitude
#   6  Phase (rad)
#   7  Autocorrelation of the first datastream (real only)
#   8  Autocorrelation of the second datastream (real only)



import numpy as np
import matplotlib.pyplot as plt
import argparse, sys, math

# Command-line interface: four boolean switches plus the positional .vis path.
# Every switch is reachable three ways: a one-letter short option, a GNU-style
# long option and a single-dash long option (e.g. -p / --plot / -plot).
_SWITCHES = (
    ('d', 'doublesideband', "Double sideband"),
    ('p', 'plot', "Plot results"),
    ('r', 'robust', "Robust Fit"),
    ('f', 'fft', "FFT data to delay spectrum"),
)

parser = argparse.ArgumentParser()
for _short, _long, _help in _SWITCHES:
    parser.add_argument('-' + _short, '--' + _long, '-' + _long,
                        help=_help, action="store_true")
parser.add_argument('visfile')

args = parser.parse_args()

def line(x, t, y):
    """Return the residuals of the line ``x[0] * t + x[1]`` against ``y``.

    This is the objective handed to ``scipy.optimize.least_squares``.

    x : two-element sequence ``(slope, intercept)``
    t : abscissa values (here, frequencies)
    y : observed ordinate values (here, phases)
    """
    slope = x[0]
    intercept = x[1]
    model = slope * t + intercept
    return model - y

def modphase(p):
    """Wrap phase(s) given in degrees into the half-open range [-180, 180).

    Operates element-wise on scalars and NumPy arrays.
    """
    # fmod(p, 360) lies in (-360, 360); adding 540 (= 360 + 180) makes it
    # strictly positive, the second fmod maps it into [0, 360), and the
    # final subtraction re-centres the interval on zero.
    shifted = np.fmod(p, 360.0) + 540
    return np.fmod(shifted, 360.0) - 180

def subtractFit(y, x, order=1):
    """Remove a least-squares polynomial trend from ``y``.

    Parameters
    ----------
    y : array_like
        Ordinate values (e.g. phases).
    x : array_like
        Abscissa values, same length as ``y``.
    order : int, optional
        Degree of the fitted polynomial (default 1, i.e. a straight line).

    Returns
    -------
    numpy.ndarray
        Residuals ``y - p(x)``, where ``p`` is the degree-``order``
        polynomial fitted to ``(x, y)`` with ``np.polyfit``.
    """
    y = np.asarray(y, dtype=float)
    # The previous body returned ``y - fitY`` without ever defining ``fitY``
    # (NameError on every call); compute the fitted values explicitly.
    coeffs = np.polyfit(x, y, order)
    fitY = np.polyval(coeffs, x)
    return y - fitY
    
# scipy is only needed for the robust (soft-L1) fit, so it is imported lazily.
if args.robust:
    from scipy.optimize import least_squares

# Load the .vis table (genfromtxt's default delimiter is whitespace); one row
# per spectral channel, columns as described in the file header.
data=np.genfromtxt(args.visfile, dtype=float)

# Zero-based column indices: 1 -> frequency offset (Hz, converted to MHz
# here), 4 -> amplitude, 5 -> phase (radians).
freq  = data[:,1]/1e6
amp   = data[:,4]
phase = data[:,5]

if args.doublesideband:
    # Rotate amplitude and phase by half the band; freq is left unchanged.
    # NOTE(review): presumably this puts the two sidebands into one
    # contiguous ordering -- confirm against the correlator's channel layout.
    n = len(freq)
    phase = np.roll(phase,n//2)
    amp = np.roll(amp, n//2)

# Rebuild the complex visibility while the phase is still in radians; in this
# script it is used only by the FFT delay search.
data_complex = amp*np.exp(1j*phase)

# From here on the phase is in degrees.  Without -d, `phase` is a view of
# `data`, so this also rescales data[:,5] in place (harmless here).
phase *=  180/np.pi

if args.fft:
    # Coarse delay search: find the peak of the delay spectrum (FFT of the
    # complex visibility across frequency).
    n = len(freq)
    # freq is in MHz, so bandwidth is in MHz and the delay step ts is in
    # microseconds.
    bandwidth = (freq[1]-freq[0])*len(freq)
    ts = 1.0/bandwidth
    # Delay axis matching the fftshift-ed spectrum (zero delay at index n//2).
    t = (np.arange(n)-n//2)*ts
    f_amp = np.abs(np.fft.fftshift(np.fft.fft(data_complex)))
    imax = np.argmax(f_amp)
    # Number of delay bins between the peak and zero delay.
    ioff = n//2-imax
    fft_delay = t[imax]
    print("FFT delay= {:.3f} usec".format(fft_delay))

    # Add a phase ramp of about ioff turns across the band, which moves the
    # FFT peak to zero delay.  The linear fit then measures only the residual
    # delay; fft_delay is folded back into `delay` after the fit.
    phase_ramp = (np.arange(n)/n)*360*ioff

    phase += np.fmod(phase_ramp,360.0)
    phase = modphase(phase)
        
# NOTE(review): leftover experiment with a hard-coded ramp for a 1024-channel
# file; kept commented out and not executed.
#phase_ramp = (np.arange(1024)/1024)*(4.8224901*16*360)
#phase = modphase(phase + phase_ramp)

# Reference the phases to their median and re-wrap to [-180, 180).
# Presumably this keeps the straight-line fit away from the +/-180 deg
# wrap point when the data carry a large constant phase offset.
av_phase = np.median(phase)
phase = modphase(phase-av_phase)

# Straight-line fit of (median-referenced) phase in degrees vs frequency in MHz.
if args.robust:
    # Soft-L1 loss down-weights outlying channels; f_scale is the soft margin
    # between inlier and outlier residuals (here 0.1 deg).  The initial guess
    # is slope = intercept = 1.
    x0 = np.ones(2)
    res = least_squares(line, x0, loss='soft_l1', f_scale=0.1, args=(freq, phase))
    fit = res.x
else:    
    fit = np.polyfit(freq,phase,1)
    
# fit = [slope (deg/MHz), intercept (deg)]; evaluate it on the channel grid.
poly = np.poly1d(fit)
fitPhase = poly(freq)

# Total phase change across the band implied by the fitted slope (deg).
deltaPhase = fit[0]*(freq[-1]-freq[0])
# deg/MHz / 360 -> turns/MHz = microseconds; x1e3 -> nanoseconds.
# Sign convention: a positive phase slope gives a negative delay.
delay = -fit[0]/(360)*1e3

# Undo the median referencing before reporting and plotting.
phase += av_phase
fitPhase += av_phase

if args.fft:
    # Account for the coarse delay removed by the FFT phase ramp
    # (fft_delay is in usec; x1000 -> nsec).
    delay -= fft_delay*1000

print("dPhase={:.3f} deg".format(deltaPhase))
# NOTE(review): this is the mean of the re-offset phases, not the median
# av_phase that was subtracted before the fit.
print("avPhase={:.3f} deg".format(np.mean(phase)))
print("delay={:.4f} nsec".format(delay))
# NOTE(review): fit[1] is the intercept of the median-referenced fit and does
# not include av_phase, unlike the plotted fitPhase.
print("Line:  {:.5f}x + {:.4f}".format(fit[0], fit[1]))

if args.plot:
    # Measured phases as points with the fitted line overlaid.
    plt.plot(freq, phase, '.')
    plt.plot(freq, fitPhase)
    plt.show()
