#!/usr/bin/env python

#**************************************************************************
#   Copyright (C) 2014 by Walter Brisken                                  *
#                                                                         *
#   This program is free software; you can redistribute it and/or modify  *
#   it under the terms of the GNU General Public License as published by  *
#   the Free Software Foundation; either version 3 of the License, or     *
#   (at your option) any later version.                                   *
#                                                                         *
#   This program is distributed in the hope that it will be useful,       *
#   but WITHOUT ANY WARRANTY; without even the implied warranty of        *
#   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         *
#   GNU General Public License for more details.                          *
#                                                                         *
#   You should have received a copy of the GNU General Public License     *
#   along with this program; if not, write to the                         *
#   Free Software Foundation, Inc.,                                       *
#   59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.             *
#**************************************************************************


from sys import argv, exit
from string import split, strip
from math import *
from os import popen
import gzip
from pylab import np, plot, show, scatter, xlabel, ylabel, title, axes, savefig

# Identity strings shown in the usage banner
program = 'plotpcal2'
version = '0.2'
author = 'Walter Brisken <wbrisken@nrao.edu>'

# Operations accepted as the first positional command line argument
legalOps = [
	'amp',
	'phase',
	'delay',
	'xy',
]

def usage(prog):
	"""Print the program banner, a one-line synopsis and a documentation link.

	prog is the program name shown in the synopsis line.
	"""
	banner = '\n%s ver. %s  %s\n' % (program, version, author)
	synopsis = 'Usage: %s [options] operation toneSelections dataFiles\n' % prog
	pointer = 'Some more details are at: http://cira.ivec.org/dokuwiki/doku.php/difx/plotpcal\n'
	for text in (banner, synopsis, pointer):
		print(text)

class Tone:
	"""One pulse cal tone: its record band, tone number within the band,
	frequency, and the real and imaginary parts of the measured value."""

	def __init__(self, recBand, toneNum, freq, re, im):
		self.recBand, self.toneNum, self.freq = recBand, toneNum, freq
		self.re, self.im = re, im

	def amp(self):
		"""Return the magnitude of the complex value (re, im)."""
		real, imag = self.re, self.im
		return sqrt(real*real + imag*imag)

	def phase(self):
		"""Return the angle of the complex value (re, im), in radians."""
		imag, real = self.im, self.re
		return atan2(imag, real)

class Record:	# all tones for one time from one file
	"""The tones measured at a single time stamp.

	mjd is the time stamp of the record and tones is the list of tone
	objects belonging to it.
	"""

	def __init__(self, mjd, tones):
		self.mjd = mjd
		self.tones = tones

	def getRecBands(self):
		"""Return the distinct record band ids of the tones, in first-seen order.

		This method used to build an empty list and return None.
		"""
		rbs = []
		for t in self.tones:
			if t.recBand not in rbs:
				rbs.append(t.recBand)
		return rbs

	def getFreqs(self):
		"""Return the frequency of every tone, one entry per tone, in tone order."""
		return [t.freq for t in self.tones]


class File:
	"""Pulse cal data from one file, or from one contiguous section of a file.

	The constructor argument is either 'path' or 'antenna:path'.  Paths ending
	in .gz or .xz are decompressed while being read.
	"""

	def __init__(self, filename):
		s = filename.strip().split(':')
		self.filename = s[-1]
		if len(s) > 1:
			self.antenna = s[0]
		else:
			self.antenna = None
		self.pcalVersion = None	# file version
		self.records = []
		self.mjdStart = 0.0
		self.mjdStop = 0.0
		self.load()
		self.index()
		self.section = '0'

	def index(self):
		"""Recompute the summary attributes from self.records.

		Sets tonesInRecBand (record band id -> tone count, taken from the
		first record and ordered by ascending band id), nRecBand, totalTones,
		mjdStart and mjdStop.  With no records the counts are zero and both
		times are 0.0.
		"""
		# Imported here so the class does not depend on a file-level import.
		# OrderedDict needs Python 2.7 or later.
		from collections import OrderedDict

		if self.records:
			firstTones = self.records[0].tones
		else:
			firstTones = []
		counts = {}
		for t in firstTones:
			counts[t.recBand] = counts.get(t.recBand, 0) + 1
		# Selection code indexes the keys by position, so they must come out
		# in a defined (ascending) order; sorting a throwaway copy of the
		# keys, as was done before, had no effect.
		self.tonesInRecBand = OrderedDict(sorted(counts.items()))
		self.nRecBand = len(self.tonesInRecBand)
		self.totalTones = len(firstTones)
		if self.records:
			self.mjdStart = self.records[0].mjd
			self.mjdStop = self.records[-1].mjd
		else:
			self.mjdStart = 0.0
			self.mjdStop = 0.0

	def split(self, maxGapSec):
		"""Cut the records into sections at frequency changes and time gaps.

		A new section starts at a record whose frequency list differs from
		the current one, or whose time is more than maxGapSec seconds after
		the previous record.  Returns [self] when there is nothing to cut,
		otherwise a list of new File objects named '<section>.<n>'.
		"""
		import copy

		if not self.records:
			return [self]
		maxGapDay = maxGapSec/86400.0
		breaks = []
		freqs = self.records[0].getFreqs()
		lastmjd = self.records[0].mjd
		for r in range(1, len(self.records)):
			fq = self.records[r].getFreqs()
			mjd = self.records[r].mjd
			if fq != freqs or mjd > lastmjd + maxGapDay:
				breaks.append(r)
				freqs = fq
			lastmjd = mjd
		breaks.append(len(self.records))
		if len(breaks) == 1:
			return [self]

		sections = []
		s = 0
		for secNum, b in enumerate(breaks):
			# A shallow copy carries filename, antenna and pcalVersion over
			# without reading and parsing the file again for every section.
			nf = copy.copy(self)
			nf.records = self.records[s:b]
			nf.index()
			nf.section = self.section + '.%d' % secNum
			sections.append(nf)
			s = b
		return sections

	def quack(self, skip):
		"""Drop the first skip records and re-index."""
		self.records = self.records[skip:]
		self.index()

	def average(self, aveSec):
		"""Not implemented; does nothing."""
		pass

	def __lt__(self, other):
		"""Order files by start time."""
		return self.mjdStart < other.mjdStart

	def summary(self):
		"""Print a short description of this file or section."""
		print('Data file:')
		print('  File = %s' % self.filename)
		print('  Section = %s' % self.section)
		if self.antenna is not None:
			print('  Antenna = %s' % self.antenna)
		else:
			print('  No antenna specified')
		if self.pcalVersion is not None:
			print('  Pcal file version = %d' % self.pcalVersion)
		else:
			print('  Weird: no Pcal version determined!')
		print('  Num records = %d' % len(self.records))
		print('  Time range: %14.8f to %14.8f' % (self.mjdStart, self.mjdStop))
		# shown as a plain dict so the text keeps its usual form
		print('  Tone counts:  %s' % (dict(self.tonesInRecBand),))

	def load(self):
		"""Read self.filename and append one Record per distinct MJD to self.records.

		The column layout is chosen from the first informative line: a
		'# File version = N' header, or else the first data line, which is
		treated as unversioned (pcalVersion -1) when its first field is
		longer than 5 characters and as version 0 otherwise.  The antenna
		name is taken from the first column when none was given and the
		file is not of the unversioned kind.
		"""
		# here choose how to load based on extension
		if self.filename[-3:] == '.gz':
			source = gzip.open(self.filename)
		elif self.filename[-3:] == '.xz':
			source = popen('xzcat %s' % self.filename)
		else:
			source = open(self.filename)
		try:
			raw = source.readlines()
		finally:
			# close the handle even if reading fails
			source.close()

		extracts = {}	# dictionary; text of MJD is key
		for r in raw:
			s = r.strip().split()
			if len(s) < 1:
				continue
			if self.pcalVersion is None:
				if r[:17] == '# File version = ':
					self.pcalVersion = int(r[17:])
					mjdColumn = 1
					extraColumns = 6
					columnsPerTone = 4
					dbeColumn = -1
					recBandColumn = -1
					freqColumn = 0
					toneColumn = 2
				elif r[0] != '#':
					if len(s[0]) > 5:
						# not versioned / from Xcube
						self.pcalVersion = -1
						mjdColumn = 0
						extraColumns = 4
						columnsPerTone = 7
						recBandColumn = 3
						dbeColumn = 2
						freqColumn = 6
						toneColumn = 0
					else:
						self.pcalVersion = 0
						mjdColumn = 1
						extraColumns = 9
						columnsPerTone = 4
						dbeColumn = -1
						recBandColumn = 0
						freqColumn = 1
						toneColumn = 2
			if r[0] == '#':
				continue
			if self.pcalVersion != -1 and self.antenna is None:
				self.antenna = s[0]

			if self.pcalVersion is not None:
				# Lines sharing an MJD accumulate into one record.  The
				# lookup must use the MJD column: testing column 0 (the
				# antenna name in most layouts) never matched and threw
				# away the tones already collected for that MJD.
				record = extracts.setdefault(s[mjdColumn], [])
				totalTones = (len(s)-extraColumns)//columnsPerTone
				lastRB = -1
				toneNum = 0
				if self.pcalVersion == 1:
					tonesPerRecBand = int(s[5])
				for t in range(totalTones):
					i = t*columnsPerTone + extraColumns
					if self.pcalVersion == 1:
						rb = t//tonesPerRecBand
					else:
						rb = int(s[i+recBandColumn])
						if dbeColumn >= 0:
							rb += 100*int(s[i+dbeColumn])
					if lastRB != rb:
						# tone numbers restart in each record band
						lastRB = rb
						toneNum = 0
					freq = int(s[i+freqColumn])
					if rb >= 0 and freq >= 0:
						record.append(Tone(rb, toneNum, freq, float(s[i+toneColumn]), float(s[i+toneColumn+1])))
						toneNum += 1

		# order by numeric time rather than by the text of the MJD
		for k in sorted(extracts.keys(), key=float):
			self.records.append(Record(float(k), extracts[k]))

# returns: op, toneSelect, files, options
def parseCommandLine():
	"""Sort the command line arguments into operation, selectors, files and options.

	Arguments starting with '-' are options: -h/--help prints the usage text
	and exits with status 0, -v/--verbose and -q/--quiet adjust the
	verbosity, -d1/-d2 set 'subday', and any other option is stored under
	its name (leading dashes removed) with the value 1.  Of the remaining
	arguments, those containing '=' are tone selectors, the first other one
	is the operation and the rest are data files.

	Returns (op, toneSelect, fileList, options).  If the files, selectors or
	a legal operation are missing, the problems are printed and the program
	exits with status 1.
	"""
	op = None
	toneSelect = []
	fileList = []
	options = {}
	options['verbose'] = 1
	options['subday'] = 1
	options['minsectionsize'] = 10
	options['maxgap'] = 10 # seconds
	options['average'] = 6 # seconds
	options['timeunits'] = 'Hours'
	options['quack'] = 1
	for a in argv[1:]:
		# startswith() also copes with an empty argument, which a[0] does not
		if a.startswith('-'):
			# this is an option
			if a in ['-h', '--help']:
				usage(argv[0])
				exit(0)
			elif a in ['-v', '--verbose']:
				options['verbose'] += 1
			elif a in ['-q', '--quiet']:
				options['verbose'] -= 1
			elif a in ['-d1']:
				options['subday'] = 1
			elif a in ['-d2']:
				options['subday'] = 2
			else:
				if len(a) > 2 and a[1] == '-':
					options[a[2:]] = 1
				else:
					options[a[1:]] = 1
		elif '=' in a:
			# this is a tone selector
			toneSelect.append(a)
		elif op is None:
			op = a
		else:
			fileList.append(a)

	die = False
	if len(fileList) == 0:
		print('Error: no data files provided')
		die = True
	if len(toneSelect) == 0:
		print('Error: no tone selection')
		die = True
	if op is None:
		print('Error: operation not specified')
		die = True
	elif op not in legalOps:
		print('Error: operation must be one of  %s' % (legalOps,))
		die = True

	if die:
		print('\nRun with --help for more info\n')
		# a usage error must not look like success to the calling shell
		exit(1)

	if options['verbose'] > 1:
		print('op = %s' % op)
		print('toneSection = %s' % (toneSelect,))
		print('data files = %s' % (fileList,))
		print('options = %s' % (options,))

	return op, toneSelect, fileList, options

def antList(fileData):
	"""Return the sorted list of distinct antenna names found in fileData."""
	return sorted(set(f.antenna for f in fileData))

def getAntennaData(fileData, A):
	"""Return the entries of fileData whose antenna equals A, sorted ascending
	using the entries' own ordering."""
	return sorted(f for f in fileData if f.antenna == A)

def calculateAmplitudes(plotData, row, rec, toneColumns):
	"""Fill one row of plotData: the record time goes into series 0 and the
	amplitude of each selected tone into series 1, 2, ..."""
	plotData[0][row] = rec.mjd
	for series, toneIndex in enumerate(toneColumns, 1):
		plotData[series][row] = rec.tones[toneIndex].amp()

def calculatePhases(plotData, row, rec, toneColumns):
	"""Fill one row of plotData: the record time goes into series 0 and the
	phase of each selected tone into series 1, 2, ..."""
	plotData[0][row] = rec.mjd
	for series, toneIndex in enumerate(toneColumns, 1):
		plotData[series][row] = rec.tones[toneIndex].phase()

def calculateComplex(plotDataX, plotDataY, row, rec, toneColumns):
	"""Fill one row of the real (plotDataX) and imaginary (plotDataY) tables
	with the values of the selected tones, one series per tone."""
	for series, toneIndex in enumerate(toneColumns):
		tone = rec.tones[toneIndex]
		plotDataX[series][row] = tone.re
		plotDataY[series][row] = tone.im

# in units of picoseconds
def calculateDelays(plotData, row, rec, toneColumns):
	"""Fill one row of plotData: the record time goes into series 0 and, for
	each pair of tone columns, the phase difference divided by 2*pi times the
	frequency difference, scaled by 1e6, into series 1, 2, ..."""
	plotData[0][row] = rec.mjd
	for series, pair in enumerate(toneColumns, 1):
		first = rec.tones[pair[0]]
		second = rec.tones[pair[1]]
		df = first.freq-second.freq
		dphase = first.phase()-second.phase()
		plotData[series][row] = 1.0e6*dphase/(2*pi*df)

def unwrapPhases(phaseData):
	"""Remove 2*pi jumps from a sequence of phases, in place.

	Each value after the first is shifted by whole turns until it lies
	within pi of the (already shifted) value before it.  An empty sequence
	is left untouched.
	"""
	# len() rather than truthiness so that numpy rows are accepted as well
	if len(phaseData) == 0:
		return
	lastPhase = phaseData[0]
	for i in range(1, len(phaseData)):
		while phaseData[i] > lastPhase+pi:
			phaseData[i] -= 2*pi
		while phaseData[i] < lastPhase-pi:
			phaseData[i] += 2*pi
		lastPhase = phaseData[i]

def unwrapDelays(delayData, ambig):
	"""Shift each delay, in place, by multiples of ambig until it lies within
	ambig/2 of the value before it; the first value is compared against 0."""
	previous = 0
	for k, _ in enumerate(delayData):
		upper = previous + ambig/2
		lower = previous - ambig/2
		while delayData[k] > upper:
			delayData[k] -= ambig
		while delayData[k] < lower:
			delayData[k] += ambig
		previous = delayData[k]

def _recBandAt(d, rbi):
	"""Return the record band id at position rbi of the ascending band ids of
	d (negative positions count from the end), or None if out of range."""
	if rbi < 0:
		rbi += d.nRecBand
	if rbi < 0 or rbi >= d.nRecBand:
		return None
	# sorted so that a position always means the same band
	return sorted(d.tonesInRecBand.keys())[rbi]

def _freqMatcher(text):
	"""Return a test on a tone frequency: 'N' matches frequency N and '%N'
	matches any multiple of N."""
	if text[0] == '%':		# freq modulo this amount, not absolute freq
		fq = int(text[1:])
		return lambda freq: freq % fq == 0
	fq = int(text)			# absolute frequency
	return lambda freq: freq == fq

def getOneToneColumns(toneSelect, d):
	"""Translate tone selectors into column indices of single tones.

	toneSelect is a list of 'key=value' strings and d a loaded data file or
	section; columns index the tones of its first record.  Keys:
	  rt=B,T   tone number T of the record band at position B
	  rf=B,F   tones of the band at position B with frequency F ('%F':
	           frequency a multiple of F)
	  f=F,...  tones with any of the listed frequencies, in any band
	  r=B,...  every tone of the listed bands
	  t=T,...  tone number T in every band
	Band positions and tone numbers may be negative to count from the end.
	Selectors with other keys are ignored; invalid ones are reported and
	skipped.  Returns the matching column indices in selector order, with
	repeats if selectors overlap.
	"""
	invalid = 'selection %s is invalid for %s section %s.  Skipping.'
	columns = []
	tones = d.records[0].tones
	for ts in toneSelect:
		tss = ts.split('=')
		key = tss[0]
		if key == 'rt':
			# look for specific tone number.  Should be unambiguous
			cd = tss[1].split(',')	# column identifying data
			rb = _recBandAt(d, int(cd[0]))
			if rb is None:
				print(invalid % (ts, d.filename, d.section))
				continue
			# the tone count is stored under the band id, not its position
			nTone = d.tonesInRecBand[rb]
			tn = int(cd[1])		# tone num
			if tn < 0:
				# negative number means count from back, python style
				tn += nTone
			if tn < 0 or tn >= nTone:
				print(invalid % (ts, d.filename, d.section))
				continue
			columns.extend(t for t, tone in enumerate(tones)
				if tone.recBand == rb and tone.toneNum == tn)
		elif key == 'rf':
			# look for frequency within a specific record band.  Should be unambiguous
			cd = tss[1].split(',')	# column identifying data
			rb = _recBandAt(d, int(cd[0]))
			if rb is None:
				print(invalid % (ts, d.filename, d.section))
				continue
			matches = _freqMatcher(cd[1])	# freq (MHz)
			columns.extend(t for t, tone in enumerate(tones)
				if tone.recBand == rb and matches(tone.freq))
		elif key == 'f':
			# look for specific frequency.  There could be many
			for c in tss[1].split(','):
				matches = _freqMatcher(c)	# freq (MHz)
				columns.extend(t for t, tone in enumerate(tones)
					if matches(tone.freq))
		elif key == 'r':
			# look for all tones in given record band
			rbList = []
			for c in tss[1].split(','):
				rb = _recBandAt(d, int(c))
				if rb is not None:
					rbList.append(rb)
			if len(rbList) == 0:
				print('selection %s is invalid for %s sec. %s.  Skipping.' % (ts, d.filename, d.section))
				continue
			columns.extend(t for t, tone in enumerate(tones)
				if tone.recBand in rbList)
		elif key == 't':
			# look for specific tone number in every band
			tnList = [int(c) for c in tss[1].split(',')]
			for t, tone in enumerate(tones):
				# the second test lets negative numbers count from the end of the band
				if tone.toneNum in tnList or tone.toneNum-d.tonesInRecBand[tone.recBand] in tnList:
					columns.append(t)

	return columns
		
def _pairToneNumbers(spec, nTone, endsLimit):
	"""Turn a pair specification into two tone numbers for a band of nTone tones.

	spec is a list starting with 'ends' (first and last tone), 'good' (a pair
	away from the band edges; bands of at most endsLimit tones use the ends),
	'mid' (the two middle tones), or two integer strings.  Negative integers
	count from the end of the band.
	"""
	word = spec[0]
	if word == 'ends':
		tn1 = 0
		tn2 = nTone-1
	elif word == 'good':
		if nTone <= endsLimit:
			tn1 = 0
			tn2 = nTone-1
		elif nTone < 8:
			tn1 = 1
			tn2 = nTone-2
		else:
			tn1 = nTone//8
			tn2 = nTone-1-tn1
	elif word == 'mid':
		tn1 = (nTone-1)//2
		tn2 = tn1+1
	else:
		tn1 = int(spec[0])
		tn2 = int(spec[1])
	if tn1 < 0:
		tn1 += nTone
	if tn2 < 0:
		tn2 += nTone
	return tn1, tn2

def _bandPairColumns(tones, rb, spec, nTone, endsLimit):
	"""Return the list of [column, column] pairs selected by spec in band rb."""
	if spec[0] == 'pairs':
		# every other column with its right-hand neighbour; stopping one
		# short of the end keeps t+1 valid for an odd number of tones
		return [[t, t+1] for t in range(0, len(tones)-1, 2)
			if tones[t].recBand == rb and tones[t+1].recBand == rb]
	tn1, tn2 = _pairToneNumbers(spec, nTone, endsLimit)
	tns = [t for t, tone in enumerate(tones)
		if tone.recBand == rb and tone.toneNum in (tn1, tn2)]
	if len(tns) == 2:
		return [tns]
	return []

def getTwoToneColumns(toneSelect, d):
	"""Translate tone selectors into pairs of column indices.

	toneSelect is a list of 'key=value' strings and d a loaded data file or
	section; columns index the tones of its first record.  Keys:
	  rt=B,SPEC  pair(s) within the record band at position B
	  t=SPEC     pair(s) within every record band
	SPEC is 'pairs', 'ends', 'good', 'mid' or two tone numbers 'T1,T2'.
	Band positions and tone numbers may be negative to count from the end.
	Bands with fewer than two tones are reported and skipped; selectors with
	other keys are ignored.  Returns a list of [column, column] lists.
	"""
	tooFew = 'selection %s is invalid for %s section %s.  Too few tones.  Skipping.'
	columns = []
	tones = d.records[0].tones
	# ascending band ids, so that a position always means the same band
	recBands = sorted(d.tonesInRecBand.keys())
	for ts in toneSelect:
		tss = ts.split('=')
		if tss[0] == 'rt':
			# pair within one specific record band
			cd = tss[1].split(',')	# column identifying data
			rbi = int(cd[0])	# rec band (0 based)
			if rbi < 0:
				rbi += d.nRecBand
			if rbi < 0 or rbi >= d.nRecBand:
				print('selection %s is invalid for %s section %s.  Skipping.' % (ts, d.filename, d.section))
				continue
			rb = recBands[rbi]
			nTone = d.tonesInRecBand[rb]
			if nTone < 2:
				print(tooFew % (ts, d.filename, d.section))
				continue
			# the two selector forms have always differed in where 'good'
			# switches away from the end tones (<= 4 here, < 4 for 't=')
			columns.extend(_bandPairColumns(tones, rb, cd[1:], nTone, 4))
		elif tss[0] == 't':
			# same pair within every record band
			cd = tss[1].split(',')	# column identifying data
			for rb in recBands:
				nTone = d.tonesInRecBand[rb]
				if nTone < 2:
					print(tooFew % (ts, d.filename, d.section))
					continue
				columns.extend(_bandPairColumns(tones, rb, cd, nTone, 3))

	return columns

def makeTimeAxis(options, plotData, nRow, start):
	"""Convert the time series plotData[0] to plotting units and label the x axis.

	When options['subday'] is not positive the times are left alone.
	Otherwise the first nRow times become offsets from a reference, scaled
	to options['timeunits']: the reference is the integer part of start for
	subday 1 and start itself for subday 2.  No reference is defined for
	other positive values.
	"""
	if not options['subday'] > 0:
		xlabel('Time (Day of Year)')
		return

	unitsPerDay = {'Days':1.0, 'Hours':24.0, 'Minutes':1440.0, 'Seconds':86400.0}
	scale = unitsPerDay[options['timeunits']]
	mode = options['subday']
	if mode == 1:
		d0 = int(start)
	elif mode == 2:
		d0 = start
	for row in range(nRow):
		plotData[0][row] = (plotData[0][row]-d0)*scale
	xlabel('Time (%s since MJD %14.8f)' % (options['timeunits'], d0))

def _drawSeries(xs, ys, seriesIndex, colors, options):
	"""Draw one data series in the color for seriesIndex (colors are reused
	cyclically): a line if the 'lines' option is present, otherwise small
	scatter points."""
	color = colors[seriesIndex % len(colors)]
	if 'lines' in options:
		plot(xs, ys, color=color, linewidth=1)
	else:
		scatter(xs, ys, color=color, s=3)

def generatePlots(op, toneSelect, data, options):
	"""Plot one operation for all the files/sections of one antenna and show it.

	op is 'amp', 'phase' or 'delay' (plotted against time) or 'xy' (imaginary
	against real part).  toneSelect is the list of tone selectors, data the
	non-empty list of sections of one antenna ordered by time, and options
	the option dictionary.  Sections with fewer records than
	options['minsectionsize'] are skipped.  The figure is shown only if at
	least one point was generated.
	"""
	colors = ['red', 'orange', 'green', 'blue', 'purple', 'brown', 'cyan', 'magenta']

	# op -> (function filling one row of the table, y axis label)
	timeSeriesOps = {
		'amp': (calculateAmplitudes, 'Amplitude'),
		'phase': (calculatePhases, 'Phase (rad):'),
		'delay': (calculateDelays, 'Delay (ps):'),
	}

	# all sections share the time reference of the first one
	start = data[0].mjdStart
	antenna = data[0].antenna

	title('Pulse cal %s for antenna %s' % (op, antenna))

	nPoints = 0

	for d in data:
		nRow = len(d.records)
		if nRow < options['minsectionsize']:
			print('skipping %s sec. %s because it has only %d records' % (d.filename, d.section, nRow))
			continue

		if op in timeSeriesOps:
			calculate, label = timeSeriesOps[op]
			if op == 'delay':
				toneColumns = getTwoToneColumns(toneSelect, d)
			else:
				toneColumns = getOneToneColumns(toneSelect, d)
			if options['verbose'] > 0:
				if op == 'delay':
					print('From %s:%s, the following tone columns will be used: %s' % (d.filename, d.section, toneColumns))
				else:
					print('From %s, the following tone columns will be used: %s' % (d.filename, toneColumns))
			nSeries = len(toneColumns)
			nPoints += nRow * nSeries
			# series 0 holds the time, series 1.. the selected tones
			plotData = np.zeros((1+nSeries, nRow))
			for row in range(nRow):
				calculate(plotData, row, d.records[row], toneColumns)

			if op == 'phase':
				for i in range(1, 1+nSeries):
					unwrapPhases(plotData[i])

			makeTimeAxis(options, plotData, nRow, start)
			ylabel(label)
			for i in range(nSeries):
				_drawSeries(plotData[0], plotData[i+1], i, colors, options)
		elif op == 'xy':
			toneColumns = getOneToneColumns(toneSelect, d)
			if options['verbose'] > 0:
				print('From %s:%s, the following tone columns will be used: %s' % (d.filename, d.section, toneColumns))
			nSeries = len(toneColumns)
			nPoints += nRow * nSeries
			plotDataX = np.zeros((nSeries, nRow))
			plotDataY = np.zeros((nSeries, nRow))
			for row in range(nRow):
				calculateComplex(plotDataX, plotDataY, row, d.records[row], toneColumns)

			xlabel('Real:')
			ylabel('Imag:')
			for i in range(nSeries):
				# the first tone gets the first color, as in the other
				# operations (this branch used to start at the last color)
				_drawSeries(plotDataX[i], plotDataY[i], i, colors, options)

	if nPoints == 0:
		print('Nothing to see!  No data points generated.')
	else:
		print('Plotted %d points' % nPoints)
		#savefig('%s-pcal-%s.pdf' % (a, op), bbox_inches='tight')
		show()



# Driver: parse the command line, load and section every data file, then
# make one plot per antenna.
op, toneSelect, fileList, options = parseCommandLine()

data = []
for f in fileList:
	fd = File(f)
	sections = fd.split(options['maxgap'])
	for s in sections:
		# a section must keep more than two records after quacking
		if not len(s.records) > options['quack']+2:
			print('Dropping %s:%s because it has only %d records' % (s.filename, s.section, len(s.records)))
			continue
		if options['quack'] > 0:
			print('Quacking %d records' % options['quack'])
			s.quack(options['quack'])
		s.summary()
		data.append(s)

antennas = antList(data)

for A in antennas:
	D = getAntennaData(data, A)
	print('Processing %s: %d files' % (A, len(D)))
	generatePlots(op, toneSelect, D, options)

# vim: tabstop=8:softtabstop=8:shiftwidth=8:noexpandtab