#!/usr/bin/env python3

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets  import RectangleSelector
import argparse, math
from datetime import datetime

def db(x):
    """Return the linear value(s) *x* expressed in decibels, i.e. 10*log10(x)."""
    decades = np.log10(x)
    return 10 * decades

# Command-line interface. Column numbers are 1-based throughout.
parser = argparse.ArgumentParser()
parser.add_argument('-x', '--xcolumn', help="Column to use as X-axis (1-based)", type=int)
parser.add_argument('-c', '--column', help="Column to use as Y-axis (1-based, repeatable)", type=int, action='append', required=True)
parser.add_argument('--skip', '-skip', help="Skip header lines", type=int)
parser.add_argument('files', nargs='+', help="Data files to plot")
parser.add_argument('--xfilt', help="Ignore any rows with X value larger than this", type=float)
parser.add_argument('--yfilt', help="Ignore any rows with Y value larger than this", type=float)
parser.add_argument('--yfiltmin', help="Ignore any rows with Y value smaller than this", type=float)
parser.add_argument('--xoffset', help="Add this to all X values", type=float, default=0)
parser.add_argument('--xaccumoffset', help="Add this incrementally to all X values between plots", type=float)
parser.add_argument('--yaccum', '-yaccum', help="Add this incrementally to all Y values, per point", type=float)
parser.add_argument('--yscale', '-yscale', help="Scale 'Y' values by this", type=float)
parser.add_argument('--xscale', '-xscale', help="Scale 'X' values by this", type=float)
parser.add_argument('--ydiv', '-ydiv', help="Divide 'Y' values by this", type=float)
parser.add_argument('--xdiv', '-xdiv', help="Divide 'X' values by this", type=float)
parser.add_argument('-o', '--outfile', help="output file to save plot as")
parser.add_argument('--label', '-label', help="Label for legend", action='append')
parser.add_argument('--flabel', '-flabel', help="Add filename to legend", action='store_true')
parser.add_argument('-n', '--norm', '-norm', help="Normalise Y by the geometric mean of the averages of these columns (1-based)", type=int, action='append')
parser.add_argument('--dpi', help="Resolution to save output as", type=int, default=300)
parser.add_argument('-t', '--title', '-title', help="Plot title")
parser.add_argument('--xlab', '-xlab', help="Label for X-axis")
parser.add_argument('--ylab', '-ylab', help="Label for Y-axis")
parser.add_argument('--xmax', help="Set maximum X value", type=float)
parser.add_argument('--xmin', help="Set minimum X value", type=float)
parser.add_argument('--ymax', help="Set maximum Y value", type=float)
parser.add_argument('--ymin', help="Set minimum Y value", type=float)
parser.add_argument('-s', '--show', help="Display plot to screen", action="store_true")
parser.add_argument('-l', '--log', help="Take the natural log of Y values", action="store_true")
parser.add_argument('-P', '-panel', '--panel', help="Each plot in a separate panel", action="store_true")
parser.add_argument('-paneltitle', '--paneltitle', help="Add filename to each panel", action="store_true")
parser.add_argument('-g', '--globalscale', help="Global scaling of panel plot", action="store_true")
parser.add_argument('--points', help="Use points, not lines", action="store_true")
parser.add_argument('--db', '-db', help="Plot as dB, not linear", action="store_true")
parser.add_argument('-p', '--poly', help="Subtract polynomial of order P from data", type=int)
parser.add_argument('--trunc', help="Wrap the X values modulo this value", type=float)
parser.add_argument('--fontsize', '-fontsize', help="Set font size", type=float)
parser.add_argument('--linewidth', '-linewidth', help="Set line width", type=float)
parser.add_argument('--roll', '-roll', help="Swap lower and upper half of values", action="store_true")
parser.add_argument('--smooth', '-smooth', help="Smooth Y with a running mean over 2N+1 points", type=int)
parser.add_argument('--wrap', '-wrap', help="Wrap Y values into the range +- this value", type=float)
parser.add_argument('--transparent', '-transparent', help="Save plot image with transparent background", action="store_true")
parser.add_argument('-i', '-interactive', '--interactive', help="Return stats of selected regions", action="store_true")
parser.add_argument('-m', '--mark', '-mark', help="Draw horizontal lines at given values", type=float, action='append')
parser.add_argument('--fullscreen', '-fullscreen', help="Show the matplotlib window full screen", action='store_true')
parser.add_argument('--time', '-time', help="Interpret X-axis column as time (e.g. 12:34:56.789)", action='store_true')

args = parser.parse_args()

# Number of header lines to skip in every input file.
skip = args.skip if args.skip is not None else 0

# The scale and divide options for one axis are mutually exclusive.
# SystemExit with a message writes it to stderr and exits with status 1
# (the bare exit() builtin is a site-module helper and is not always present).
if args.xscale is not None and args.xdiv is not None:
    raise SystemExit("Error: Cannot pass -xscale and -xdiv")
if args.yscale is not None and args.ydiv is not None:
    raise SystemExit("Error: Cannot pass -yscale and -ydiv")

doglobal = args.globalscale
xcol = args.xcolumn

# Show on screen when asked to, or whenever there is no output file.
show = args.show or args.outfile is None

title = args.title
point = '.' if args.points else ''  # matplotlib format string: dots or a plain line
subPoly = args.poly

if args.fontsize is not None:
    plt.rcParams.update({'font.size': args.fontsize})

t0 = None  # Time origin shared by every file loaded with --time (first stamp read).


def load_data_with_time(file, time_col=0, skip=0, time_format="%H:%M:%S.%f"):
    """Load a text data file whose *time_col* holds time stamps.

    Every cell is first read as text with ``np.genfromtxt``. The time column
    is parsed with *time_format* and replaced by the number of seconds since
    the global ``t0``. ``t0`` is set from the first stamp of the first file
    loaded, so several files share one time axis. All other cells are
    converted to float, and cells that are not numeric become NaN.

    Parameters
    ----------
    file : path (or anything ``np.genfromtxt`` accepts) of the data file
    time_col : 0-based index of the time column
    skip : number of header lines to skip
    time_format : ``datetime.strptime`` format of the time stamps

    Returns
    -------
    Float ndarray with the same rows/columns as the file, with the time
    column converted to elapsed seconds.

    Raises
    ------
    ValueError if a time stamp does not match *time_format*.
    """
    global t0
    data = np.genfromtxt(file, dtype=str, skip_header=skip)

    times = [datetime.strptime(t, time_format) for t in data[:, time_col]]
    if t0 is None:
        t0 = times[0]
    # NOTE(review): the default format carries no date, so a log that crosses
    # midnight jumps back by one day here -- confirm whether that can occur.
    x = np.array([(t - t0).total_seconds() for t in times])

    def safe_float(val):
        """float(val), or NaN when *val* is not a number."""
        try:
            return float(val)
        except ValueError:
            return np.nan

    # otypes fixes the output dtype up front instead of inferring it from a
    # trial call on the first element (which also fails for empty input).
    data_float = np.vectorize(safe_float, otypes=[float])(data)

    # Replace the (NaN) time column with numeric seconds.
    data_float[:, time_col] = x
    return data_float

# Load every input file into a float array (one entry per file, in order).
if args.time and xcol is None:
    # With --time the X column holds the time stamps, so it must be given;
    # otherwise xcol - 1 below fails with a TypeError on None.
    parser.error("--time requires -x/--xcolumn to select the time column")

data = []
for file in args.files:
    if args.time:
        d = load_data_with_time(file, xcol - 1, skip)
    else:
        d = np.genfromtxt(file, dtype=float, skip_header=skip)
    data.append(d)

def subtractFit(y, x, order=1):
    """Remove a least-squares polynomial trend from *y*.

    Fits a polynomial of degree *order* to the points (x, y) and returns the
    residuals ``y - fit(x)``.
    """
    trend = np.poly1d(np.polyfit(x, y, order))
    return y - trend(x)

# Overall Y range across panels, filled in for --globalscale.
gmin = gmax = None

# Running X shift added to successive traces (--xaccumoffset); advanced by doplot().
xaccum = 0

if args.interactive:
    # Traces available for rectangle statistics, as (x, y, label) tuples.
    plotData = []
    # Mean and standard deviation of the latest reference region (used for SNR).
    rmean = rrms = None

def calcNorm(d, col):
    """Return the geometric mean of the means of the 1-based columns *col* of *d*.

    Raises ZeroDivisionError for an empty *col*. Raises ValueError (from
    math.pow) when the product of the means is negative and *col* has more
    than one entry.
    """
    product = 1
    for column_mean in (np.mean(d[:, c - 1]) for c in col):
        product *= column_mean
    return math.pow(product, 1 / len(col))

def _axis_scale(scale, div):
    """Return the multiplier for one axis from its --*scale / --*div options.

    *scale* is used when given, otherwise ``1/div``, otherwise 1.
    """
    if scale is not None:
        return scale
    if div is not None:
        return 1 / div
    return 1


def doplot(d, xcol, ycol, i, nx=None, ny=None, label=None, norm=1, panel=False):
    """Transform one Y column of a loaded file according to the options and plot it.

    Parameters
    ----------
    d : 2-D array of file data (rows x columns).
    xcol : 1-based column to use as X, or None to use the row number.
    ycol : 1-based column to use as Y.
    i : running (file, column) plot counter; selects panel ``i + 1`` in
        panel mode.
    nx, ny : subplot grid shape, or None for a single set of axes.
    label : legend label of the trace (also written in the panel when *panel*).
    norm : divisor applied to the Y values.
    panel : whether the plot is drawn in panel mode.

    Processing order: scale/normalise Y, optional half-swap (--roll), build X
    (scale and offset), row filters, Y accumulation, wrap, log/dB, polynomial
    subtraction, X modulo (--trunc), smoothing, then plot.

    Side effects: draws on the current figure and advances the global
    ``xaccum`` by --xaccumoffset. With --interactive it also appends
    ``(x, y, label)`` to the global ``plotData``.
    """
    global xaccum, plotData

    y = d[:, ycol - 1] / norm * _axis_scale(args.yscale, args.ydiv)
    xscale = _axis_scale(args.xscale, args.xdiv)

    if args.roll:
        # Swap the lower and upper halves of Y; X is deliberately left as is.
        half = len(y) // 2
        y = np.concatenate([y[half:], y[:half]])

    # --xoffset applies to every X value, including the row-number X axis.
    xoff = args.xoffset if args.xoffset is not None else 0
    if xcol is not None:
        x = d[:, xcol - 1] * xscale + xoff
    else:
        x = np.arange(len(y)) * xscale + xoff

    # Row filters, combined into one mask (same as filtering one after another).
    keep = np.ones(len(y), dtype=bool)
    if args.xfilt is not None:
        keep &= x <= args.xfilt
    if args.yfilt is not None:
        keep &= y <= args.yfilt
    if args.yfiltmin is not None:
        keep &= y >= args.yfiltmin
    x = x[keep]
    y = y[keep]

    if args.yaccum is not None:
        y = y + np.arange(len(y)) * args.yaccum

    if args.wrap is not None:
        y = np.remainder(y + args.wrap, args.wrap * 2) - args.wrap

    if args.log:
        y = np.log(y)
    elif args.db:
        y = db(y)
        # Put the peak at 0 dB. nanmax, so NaN cells (e.g. non-numeric entries
        # from --time loading) do not turn the whole trace into NaN.
        y = y - np.nanmax(y)

    if subPoly is not None:
        y = subtractFit(y, x, subPoly)

    if args.trunc is not None:
        x = np.fmod(x, args.trunc)

    if nx is not None:
        ax = plt.subplot(nx, ny, i + 1)
        if args.paneltitle or args.flabel:
            # i counts (file, column) pairs; map it back to its source file.
            ax.text(0.05, 0.95, args.files[i // len(args.column)],
                    transform=ax.transAxes, ha='left', va='top')

    if args.smooth is not None:
        span = args.smooth
        width = 2 * span + 1
        y = np.convolve(y, np.ones(width) / width, mode="valid")
        # mode="valid" drops `span` samples at each end. Use an explicit stop:
        # x[span:-span] would be empty for span == 0.
        x = x[span:len(x) - span]

    xplot = x + xaccum
    pl, = plt.plot(xplot, y, point, label=label)
    if args.interactive:
        plotData.append((xplot, y, label))
    else:
        pl.set_picker(True)  # lets a click on the line fire a pick_event

    if args.linewidth is not None:
        pl.set_linewidth(args.linewidth)

    if args.xaccumoffset is not None:
        xaccum += args.xaccumoffset

    if panel and label is not None:
        ax = plt.gca()
        ax.text(0.05, 0.95, label, transform=ax.transAxes, ha='left', va='top')


def drawMarks(marks, ax):
    """Draw a light-blue horizontal line on *ax* at each value in *marks*.

    *ax* may be an Axes or the pyplot module; both provide ``axhline``.
    """
    for level in marks:
        ax.axhline(y=level, linewidth=2, color='lightblue')

# Legend label for each input file: explicit --label values take precedence,
# then the file names (--flabel), otherwise no label.
if args.label is not None:
    # Pad with None so giving fewer --label values than files does not raise
    # IndexError in the plotting loop.
    label = list(args.label) + [None] * (len(args.files) - len(args.label))
elif args.flabel:
    label = args.files
else:
    label = (None,) * len(args.files)

nx = None
ny = None
if args.panel:
    # doplot() opens a new panel for every (file, column) pair, so size the
    # grid for all of them (sizing by files alone overflows with several -c).
    nplot = len(data) * len(args.column)
    nx = math.ceil(math.sqrt(nplot))
    ny = math.ceil(nplot / nx)
else:
    doglobal = False  # global Y scaling only applies across panels

if args.interactive:
    # Rectangle-selection statistics. A drag with button 1 marks a "signal"
    # region (red outline); any other button accepted by the selector marks a
    # "reference" region (green outline).
    rsignal = None     # Rectangle patch of the current signal region
    rreference = None  # Rectangle patch of the current reference region

    def inRect(x1, x2, y1, y2, data, reference):
        """Print statistics for the points of one trace strictly inside a box.

        *data* is an ``(x, y, label)`` tuple as stored in ``plotData``. For a
        reference box the mean and standard deviation are also stored in the
        globals ``rmean``/``rrms``. For a signal box the SNR
        ``(max - rmean) / rrms`` is printed once a reference exists. A trace
        with no points in the box is reported and skipped.
        """
        global rmean, rrms
        xs = np.asarray(data[0])
        ys = np.asarray(data[1])
        label = data[2]

        inside = (xs > x1) & (xs < x2) & (ys > y1) & (ys < y2)
        ypoints = ys[inside]

        if label is not None:
            print(label, ": ")
        if ypoints.size == 0:
            # max()/min() raise on an empty array, which would abort the loop
            # in addRect() before the remaining traces were reported.
            print("No points in selected region")
            return

        ymax = ypoints.max()
        ymin = ypoints.min()
        ymean = ypoints.mean()

        print("Max = {}".format(ymax))
        print("Min = {}".format(ymin))
        print("Mean = {}".format(ymean))

        if reference:
            ystd = ypoints.std()
            print("Std = {}".format(ystd))
            rmean = ymean
            rrms = ystd
        elif rmean is not None and rrms is not None:
            print("SNR = {}".format((ymax - rmean) / rrms))

    def addRect(rect, x1, x2, y1, y2, reference=True):
        """Replace *rect* with a dashed outline of the new box and report its stats.

        Statistics are printed for every trace in ``plotData``. Returns the new
        Rectangle so the caller can remove it on the next selection.
        """
        if rect is not None:
            rect.remove()

        colour = 'g' if reference else 'r'
        rect = plt.Rectangle((min(x1, x2), min(y1, y2)), np.abs(x1 - x2),
                             np.abs(y1 - y2), linestyle='--', edgecolor=colour,
                             facecolor=None, fill=False)
        plt.gca().add_patch(rect)
        for d in plotData:
            inRect(x1, x2, y1, y2, d, reference)
        return rect

    def line_select_callback(eclick, erelease):
        """RectangleSelector callback: button 1 -> signal box, otherwise reference box."""
        global rsignal, rreference
        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        if x1 == x2 and y1 == y2:
            return  # a plain click, not a drag

        # Normalise so that (x1, y1) is the lower-left corner.
        if x1 > x2:
            x1, x2 = x2, x1
        if y1 > y2:
            y1, y2 = y2, y1

        if eclick.button == 1:
            rsignal = addRect(rsignal, x1, x2, y1, y2, False)
        else:
            rreference = addRect(rreference, x1, x2, y1, y2, True)

plot_index = 0  # running (file, column) counter; selects the panel in --panel mode
for file_index, d in enumerate(data):
    # One normalisation factor per file, shared by all of its Y columns.
    norm = calcNorm(d, args.norm) if args.norm is not None else 1
    for c in args.column:
        doplot(d, xcol, c, plot_index, nx, ny, label[file_index], norm, args.panel)
        plot_index += 1
        
# Number of subplot panels: in panel mode doplot() opens one per
# (file, column) pair; otherwise there is a single set of axes.
npanel = len(data) * len(args.column) if nx is not None else 1

if args.mark is not None:
    if nx is not None:
        for i in range(npanel):
            drawMarks(args.mark, plt.subplot(nx, ny, i + 1))
    else:
        drawMarks(args.mark, plt)

if (args.label is not None or args.flabel) and not args.panel:
    plt.legend()

# tight_layout() is called once, at the very end. A panel suptitle needs the
# top 5% of the figure kept free; a second tight_layout() call without this
# rect would undo that and overlap the title.
layout_rect = None
if title is not None:
    if args.panel:
        plt.suptitle(title)
        layout_rect = [0, 0, 1, 0.95]
    else:
        plt.title(title)

if args.xlab is not None:
    plt.xlabel(args.xlab)
if args.ylab is not None:
    plt.ylabel(args.ylab)

if doglobal:  # only ever True together with --panel
    # Common Y range = union of every panel's autoscaled range.
    for i in range(npanel):
        lo, hi = plt.subplot(nx, ny, i + 1).get_ylim()
        gmin = lo if i == 0 else min(gmin, lo)
        gmax = hi if i == 0 else max(gmax, hi)
    if args.ymin is None:
        args.ymin = gmin
    if args.ymax is None:
        args.ymax = gmax

# Explicit axis limits apply to every panel, not only to the last one drawn.
want_x = args.xmin is not None or args.xmax is not None
want_y = args.ymin is not None or args.ymax is not None
if want_x or want_y:
    if nx is not None:
        limit_axes = [plt.subplot(nx, ny, i + 1) for i in range(npanel)]
    else:
        limit_axes = [plt.gca()]
    if want_x:
        for ax in limit_axes:
            xmin, xmax = ax.get_xlim()
            if args.xmin is not None:
                xmin = args.xmin
            if args.xmax is not None:
                xmax = args.xmax
            ax.set_xlim(xmin, xmax)
    if want_y:
        for ax in limit_axes:
            ymin, ymax = ax.get_ylim()
            if args.ymin is not None:
                ymin = args.ymin
            if args.ymax is not None:
                ymax = args.ymax
            ax.set_ylim(ymin, ymax)

if layout_rect is not None:
    plt.tight_layout(rect=layout_rect)
else:
    plt.tight_layout()

if args.outfile is not None:
    plt.savefig(args.outfile, dpi=args.dpi, transparent=args.transparent)

def on_pick(event):
    """Matplotlib ``pick_event`` handler: print the label of the picked artist."""
    picked = event.artist.get_label()
    print(f"You clicked on {picked}")

if show:
    if not args.interactive:
        # Lines are pickable (see doplot); report the label of a clicked line.
        plt.gcf().canvas.mpl_connect('pick_event', on_pick)
    else:
        # Drag with button 1 or 3 to select a region. The selector is kept in
        # the module-level name `rs`, presumably so it is not garbage collected.
        rs = RectangleSelector(
            plt.gca(),
            line_select_callback,
            useblit=True,
            button=[1, 3],
            spancoords='data',
            interactive=False,
        )

    if args.fullscreen:
        plt.get_current_fig_manager().full_screen_toggle()

    plt.show()
