#!/usr/bin/env python3

# Note: this utility can run under python2.7 or python3

import getopt
from sys import argv, exit, stdout
from os import environ, system, getcwd
from re import search
from time import sleep
from math import sqrt

# Identification strings shown in the usage banner; program and version are
# also stamped into the header of every generated .v2d file.
program = 'difxspeed'
version = '0.6'
verdate = '20190916'
author = 'Walter Brisken <wbrisken@nrao.edu>'

# Suffix that the benchmark description file named on the command line must carry.
inputFileExtension = '.difxspeed'

# TODO
# 2. read antenna list from vex
# 5. get dummy timestamps; populate column descriptions

def usage(pgm):
	"""Print the version banner and command-line help.

	pgm is the name the program was invoked as; it is shown in the
	'Usage:' line.
	"""
	helpLines = (
		'\n%s ver. %s  %s  %s' % (program, version, author, verdate),
		'\nUsage: %s <benchmarkFile> [<numIterations>]\n' % pgm,
		'Where:',
		'  <benchmarkFile> describes the series of benchmarks to run.',
		'                  (must end in %s)\n' % inputFileExtension,
		'  <numIterations> is the number of times to run the test.\n',
		'See https://www.atnf.csiro.au/vlbi/dokuwiki/doku.php/difx/difxspeed for',
		'more information.\n',
	)
	for text in helpLines:
		print(text)

# Running count of commands passed to execute(); used to label its output.
nExecute = 0
def execute(cmd, verbose, pretend):
	"""Announce a shell command and, unless pretending, run it.

	cmd     -- command line handed to the shell
	verbose -- accepted for interface compatibility; not consulted here
	pretend -- when true the command is only reported, never run

	Each call bumps the global nExecute counter, which prefixes both
	lines of the report together with the current working directory.
	"""
	global nExecute
	nExecute += 1
	if pretend:
		report = '[%d] In: %s\n[%d] Pretending: %s\n'
	else:
		report = '[%d] In: %s\n[%d] Executing: %s\n'
	print(report % (nExecute, getcwd(), nExecute, cmd))
	if not pretend:
		system(cmd)

def firstconfig(params):
	"""Build the initial configuration iterator state from params.

	params maps a parameter name to its list of candidate values.  The
	result holds one [name, index, values] entry per parameter, with
	every index set to 0 (the first candidate).  The values lists are
	shared with params, not copied.
	"""
	return [[name, 0, params[name]] for name in params]

def nextconfig(config):
	"""Advance config, in place, to the next combination of values.

	The entries behave like the digits of an odometer with the first
	entry turning fastest: an index that runs past its value list wraps
	to 0 and carries into the following entry.

	Returns config itself after a successful step, or None once every
	index has wrapped (all combinations visited; indices are back at 0).
	"""
	for entry in config:
		bumped = entry[1] + 1
		if bumped < len(entry[2]):
			entry[1] = bumped
			return config
		# this entry wrapped around; carry into the next one
		entry[1] = 0

	return None

def printconfig(config):
	"""Print the currently selected value of every parameter in config.

	config is a list of [name, index, values] entries; for each one the
	line '  name = values[index]' is printed under a 'Config:' heading.
	"""
	# heading was previously misspelled as 'Congig:'
	print('Config:')
	for c in config:
		name, index, values = c[0], c[1], c[2]
		print('  %s = %s' % (name, values[index]))

def writeparams(fd, config):
	"""Write the current value of each varying parameter to fd.

	Only parameters with more than one candidate value are written, each
	followed by a single space, in config order.  No newline is emitted.
	"""
	for entry in config:
		choices = entry[2]
		if len(choices) <= 1:
			# fixed parameter: not a table column
			continue
		fd.write('%s ' % choices[entry[1]])

def writeheader(fd, config, cluster):
	"""Write the commented header of the results table to fd.

	Three sections are produced: the cluster definition (one line per
	cluster key with its values), the fixed parameters (those with a
	single candidate value), and a numbered description of the table
	columns: first the varying parameters in config order, then the
	average, the RMS and the individual execute times.  fd is flushed
	at the end.
	"""
	fd.write('# Cluster definition:\n')
	for key in cluster:
		fd.write('# %s =' % key)
		fd.write(''.join([' %s' % member for member in cluster[key]]))
		fd.write('\n')
	fd.write('\n')

	fd.write('# Fixed parameters:\n')
	fixed = [entry for entry in config if len(entry[2]) == 1]
	for entry in fixed:
		fd.write('# %s = %s\n' % (entry[0], entry[2][0]))
	fd.write('\n')

	fd.write('# Table columns:\n')
	varying = [entry[0] for entry in config if len(entry[2]) > 1]
	for number, name in enumerate(varying, 1):
		fd.write('# %d  %s\n' % (number, name))
	# the statistics columns follow directly after the varying parameters
	nextCol = len(varying) + 1
	fd.write('# %d  Average execute time (seconds)\n' % nextCol)
	fd.write('# %d  RMS of execute times (seconds)\n' % (nextCol + 1))
	fd.write('# %d...  Individual execute times (seconds)\n' % (nextCol + 2))

	fd.flush()

def getnant(config, cluster):
	"""Return the number of antennas to use for the current configuration.

	The currently selected 'nAnt' value in config is used when it does
	not exceed the number of entries in cluster['antennas'].  A larger
	request is reported with an error message and skipped.  When no
	usable 'nAnt' entry exists, all listed antennas are used.
	"""
	available = len(cluster['antennas'])
	for entry in config:
		if entry[0] != 'nAnt':
			continue
		requested = int(entry[2][entry[1]])
		if requested <= available:
			return requested
		print('Error: %d antennas is more than listed (%d)' % (requested, available))
	# default to all antennas
	return available

def getncore(config, cluster):
	"""Return the number of core nodes to use for the current configuration.

	The currently selected 'nCore' value in config is used when it does
	not exceed the number of entries in cluster['cores'].  A larger
	request is reported with an error message and skipped.  When no
	usable 'nCore' entry exists, all listed cores are used.
	"""
	available = len(cluster['cores'])
	for entry in config:
		if entry[0] != 'nCore':
			continue
		requested = int(entry[2][entry[1]])
		if requested <= available:
			return requested
		print('Error: %d cores is more than listed (%d)' % (requested, available))

	# default to all cores
	return available

def _writev2d(v2dfile, config, cluster, nAnt, nCore, doFake, remainder):
	"""Write the .v2d file handed to vex2difx for one benchmark run.

	v2dfile   -- path of the file to create
	config    -- list of [name, index, values] entries (current selection)
	cluster   -- dict with the 'datastreams', 'cores' and 'antennas' lists
	nAnt      -- number of antennas (and datastream nodes) to use
	nCore     -- number of core nodes to use
	doFake    -- when true, request fake data and emit ANTENNA sections
	remainder -- lines copied verbatim to the end of the file
	"""
	setupParams = ['tInt', 'fftSpecRes', 'specRes', 'strideLength', 'xmacLength', 'numBufferedFFTs', 'subintNS', 'doPolar']
	antennaParams = ['toneSelection']
	dontWriteParams = ['nAnt']
	antennas = cluster['antennas']

	# the with-block closes the file even if a lookup below raises
	with open(v2dfile, 'w') as o:
		o.write('# written by %s ver %s\n\n' % (program, version))

		if doFake:
			o.write('fake = true\n')
		o.write('singleScan = false\n')
		o.write('startSeries = 0\n')

		# machines: head node, one datastream node per antenna, then the cores
		machines = [environ['DIFX_HEAD_NODE']]
		machines += [cluster['datastreams'][a] for a in range(nAnt)]
		machines += [cluster['cores'][c] for c in range(nCore)]
		o.write('machines = %s\n' % ','.join(machines))

		o.write('antennas = %s\n' % ','.join([antennas[a] for a in range(nAnt)]))

		# only write antenna sections if in fake mode
		if doFake:
			for a in range(nAnt):
				# FIXME: add other antenna parameters here
				o.write('ANTENNA %s\n' % antennas[a])
				o.write('{\n')
				o.write('  source = fake\n')
				o.write('}\n')

		# parameters that are neither setup, antenna nor suppressed ones
		# are written as global settings
		for c in config:
			if c[0] not in setupParams and c[0] not in antennaParams and c[0] not in dontWriteParams:
				o.write('%s = %s\n' % (c[0], c[2][c[1]]))

		# setup parameters go inside the SETUP section
		o.write('SETUP default\n')
		o.write('{\n')
		for c in config:
			if c[0] in setupParams and c[0] not in dontWriteParams:
				o.write('  %s = %s\n' % (c[0], c[2][c[1]]))
		o.write('}\n')

		# remainder of the benchmark file is passed through unchanged
		if len(remainder) > 0:
			o.write('\n')
			for r in remainder:
				o.write(r)

def _readexecutetimes(logfile):
	"""Return the list of execute times found in logfile.

	A time is a number enclosed in double asterisks (e.g. **12.34**);
	at most one is taken per line.  Raises IOError if the file cannot
	be opened or read.
	"""
	times = []
	with open(logfile, 'r') as log:
		for line in log:
			# raw string: '\*' is a regex escape, not a Python string escape
			m = search(r'\*\*([0-9.]+)\*\*', line)
			if m is not None:
				times.append(float(m.group(1)))
	return times

def runtest(fd, config, cluster, filebase, verbose, run, doFake, remainder):
	"""Run one benchmark configuration and append its result row to fd.

	fd        -- open results table; receives the varying parameter
	             values, then average, RMS and individual execute times,
	             then '  # <filebase>'
	config    -- list of [name, index, values] entries (current selection)
	cluster   -- dict with the 'datastreams', 'cores' and 'antennas' lists
	filebase  -- base name for the .v2d, .calc, .im and .difxlog files
	verbose   -- passed through to execute()
	run       -- when true, generate the .v2d file and run vex2difx, the
	             calc program (DIFX_CALC_PROGRAM, default calcif2) and
	             startdifx; when false only an existing log is read
	doFake    -- request fake data in the .v2d file
	remainder -- extra lines appended verbatim to the .v2d file

	Returns 0 on success, or -1 if <filebase>.difxlog cannot be read (the
	parameter values have already been written to fd at that point).
	"""
	writeparams(fd, config)

	calcProgram = environ.get('DIFX_CALC_PROGRAM', 'calcif2')

	nAnt = getnant(config, cluster)
	nCore = getncore(config, cluster)

	if run:
		v2dfile = filebase + '.v2d'
		_writev2d(v2dfile, config, cluster, nAnt, nCore, doFake, remainder)

		# run vex2difx
		execute('vex2difx %s' % v2dfile, verbose, False)

		# run calcif2; remove any stale .im file first
		execute('rm -f %s.im' % filebase, verbose, False)
		execute('%s %s.calc' % (calcProgram, filebase), verbose, False)

		# run difx
		execute('startdifx -n -f %s' % filebase, verbose, False)

		# make sure difxlog is closed
		sleep(1)

	# only the log read is guarded, so a failure writing to fd is not
	# misreported as a log problem
	try:
		executeTime = _readexecutetimes('%s.difxlog' % filebase)
	except IOError:
		print('Error reading %s.difxlog' % filebase)
		return -1

	n = len(executeTime)
	if n > 0:
		a = sum(executeTime, 0.0)/n
		ss = sum([et*et for et in executeTime], 0.0)
		# rounding can push the variance slightly below zero when the
		# times are (nearly) identical; clamp so sqrt() cannot raise
		r = sqrt(max(ss/n - a*a, 0.0))
		fd.write(' %f %f ' % (a, r))
		for et in executeTime:
			fd.write(' %f' % et)

	fd.write('  # %s' % filebase)
	fd.write('\n')
	fd.flush()

	return 0

def benchmark(basename, verbose, run):
	"""Run every parameter combination described by <basename>.difxspeed.

	basename -- benchmark file name without its extension; also the
	            prefix of all generated files
	verbose  -- passed through to runtest()
	run      -- passed through to runtest(); when false no programs are
	            run and only existing logs are read

	Input file format: 'key = v1, v2, ...' lines, with '#' starting a
	comment.  Lines that do not contain exactly one '=' are ignored.  A
	line consisting only of '--' ends the parameter section; everything
	after it is copied verbatim into each .v2d file.  'fake' set to f,
	false or 0 disables fake-data mode.  The keys 'datastreams', 'cores'
	and 'antennas' describe the cluster; all other keys are benchmark
	parameters, of which 'vex' and 'nThread' are required.

	Results go to <basename>.out, which is opened in 'w' mode and so
	replaces any earlier file of that name.  It is only opened after
	the input has been parsed and validated, so a bad input file no
	longer truncates earlier results or leaks the file handle.

	Returns -1 if a required parameter is missing, otherwise None.
	"""
	inputFile = basename + inputFileExtension
	params = {}
	cluster = {}
	clusterParams = ['datastreams', 'cores', 'antennas']
	requiredParams = ['vex', 'nThread']
	doFake = True

	# parse input file
	with open(inputFile) as f:
		data = f.readlines()

	remainder = []  # explicitly include in .v2d files
	for i, d in enumerate(data):
		if d.strip() == '--':
			remainder = data[i+1:]
			break
		s = d.split('#')[0].strip().split('=')
		if len(s) != 2:
			continue
		key = s[0].strip()
		if key == 'fake' and s[1].strip().lower() in ['f', 'false', '0']:
			doFake = False
		else:
			L = [v.strip() for v in s[1].split(',')]
			if key in clusterParams:
				cluster[key] = L
			else:
				params[key] = L

	bad = 0
	for r in clusterParams:
		if r not in cluster:
			print('Error: required parameter %s not set' % r)
			bad += 1
	for r in requiredParams:
		if r not in params:
			print('Error: required parameter %s not set' % r)
			bad += 1
	if bad > 0:
		return -1

	# generate an iterator-type object to keep track of the configurations
	it = firstconfig(params)

	with open('%s.out' % basename, 'w') as fd:
		writeheader(fd, it, cluster)
		sequenceNumber = 1

		while it is not None:
			filebase = '%s-%03d' % (basename, sequenceNumber)
			if sequenceNumber == 1:
				# run a first dummy test
				v = runtest(fd, it, cluster, '%s-dummy' % basename, verbose, run, doFake, remainder)
				if v < 0:
					break
			v = runtest(fd, it, cluster, filebase, verbose, run, doFake, remainder)
			if v < 0:
				break
			it = nextconfig(it)
			sequenceNumber += 1

def testenviron():
	requiredEnviron = ['DIFX_HEAD_NODE']

	bad = 0
	for e in requiredEnviron:
		if not e in environ.keys():
			print('Required environment variable %s not set' % e)
			bad += 1

	if bad > 0:
		exit(0)


def getbasename(infile):
	"""Return infile without its required extension, or None if invalid.

	A valid name starts with a letter, continues with letters and digits
	only, and ends exactly with inputFileExtension.  The extension is
	matched literally and the pattern is anchored at the end, so names
	such as 'fooXdifxspeed' or 'foo.difxspeed.bak' are rejected instead
	of silently mapping to a different file.
	"""
	from re import escape

	pattern = '^([a-zA-Z][a-zA-Z0-9]*)%s$' % escape(inputFileExtension)
	m = search(pattern, infile)
	if m is None:
		return None
	return m.group(1)

#---

verbose = 1
nIter = 1

# no benchmark file given, or help explicitly requested: show usage and stop
if len(argv) < 2 or argv[1] in ['-h', '--help']:
	usage(argv[0])
	exit(0)

# optional second argument: number of iterations
if len(argv) > 2:
	nIter = int(argv[2])

testenviron()

basename = getbasename(argv[1])
if basename is None:
	print('Error: input file must end with %s, must start with a letter, and must not have any special characters in it.' % inputFileExtension)
	exit(0)

if nIter <= 0:
	# zero or negative count: do not run anything, just process existing logs
	benchmark(basename, verbose, False)
else:
	for i in range(nIter):
		print('Iteration %d of %d' % (i+1, nIter))
		benchmark(basename, verbose, True)