#!/usr/bin/env python3

import os
import collections
import numpy as np
import argparse
import surfa as sf


# parse input
parser = argparse.ArgumentParser(description='creates a table of the differences between two stats tables')
parser.add_argument('--t1', required=True, help='input table 1 (output of asegstats2table or aparcstats2table)')
parser.add_argument('--t2', required=True, help='input table 2 (output of asegstats2table or aparcstats2table)')
parser.add_argument('--o', required=True, help='output file')
parser.add_argument('--percent', action='store_true', help='compute percent diff with respect to mean of tables')
parser.add_argument('--percent1', action='store_true', help='compute percent diff with respect to table1')
parser.add_argument('--percent2', action='store_true', help='compute percent diff with respect to table2')
# --mul and --div each take exactly one float. They are declared without
# nargs so the parsed value is a plain float, the same kind of value as the
# default, rather than a one-element list.
parser.add_argument('--mul', default=1.0, type=float, help='multiply by mulval')
parser.add_argument('--div', default=1.0, type=float, help='divide by divval')
parser.add_argument('--common', action='store_true', help='compute diff on common segs (may reorder)')
parser.add_argument('--rm-exvivo', action='store_true', help='remove the string "_exvivo" from the column header')
# stored as args.diff (see dest), not args.diff_subjs
parser.add_argument('--diff-subjs', action='store_true', default=False, dest='diff', help='pass this flag to compare subjects with different names')
parser.add_argument('--noreplace53', action='store_true', help='do not replace 5.3 structure names with later names')
args = parser.parse_args()

# table reading utility
def read_table(tablefile, Replace53):
    """Read a whitespace-separated stats table.

    The first line is the header: the text after the last ':' in its first
    token is taken as the measure name, and the remaining tokens are the
    structure (column) names. Every following non-blank line holds a subject
    name followed by one numeric value per structure.

    When the module-level ``args.rm_exvivo`` is set, every occurrence of
    "_exvivo" is removed from the header line before it is split.

    Errors are reported through ``sf.system.fatal`` (presumably this aborts
    the program -- confirm against surfa).

    Args:
        tablefile: path of the table file to read.
        Replace53: when true, structure names found in the rename mapping
            below are replaced by their mapped name.

    Returns:
        A tuple ``(structlist, values, measure)``: the list of structure
        names (after any renaming), an ordered mapping of subject name to
        ``{structure: float value}``, and the measure name.
    """

    # check if file exists
    if not os.path.isfile(tablefile):
        sf.system.fatal('table file does not exist at %s' % tablefile)

    # structure names that are swapped for another spelling when Replace53 is set
    renames = {
        'Left-Thalamus-Proper': 'Left-Thalamus',
        'Right-Thalamus-Proper': 'Right-Thalamus',
        'CerebralWhiteMatterVol': 'CorticalWhiteMatterVol',
        'lhCerebralWhiteMatterVol': 'lhCorticalWhiteMatterVol',
        'rhCerebralWhiteMatterVol': 'rhCorticalWhiteMatterVol',
    }

    # read table data
    with open(tablefile, 'r') as file:
        header = file.readline()
        if args.rm_exvivo:
            header = header.replace('_exvivo', '')
        header = header.split()

        # an empty file or blank first line would otherwise fail with a bare IndexError
        if not header:
            sf.system.fatal('table file %s has no header line' % tablefile)

        measure = header[0].split(':')[-1]
        structlist = header[1:]
        if Replace53:
            structlist = [renames.get(name, name) for name in structlist]

        values = collections.OrderedDict()
        for line in file:
            fields = line.split()
            if not fields:
                continue
            subject, cells = fields[0], fields[1:]

            # zip() would silently truncate a misaligned row, so reject it explicitly
            if len(cells) != len(structlist):
                sf.system.fatal('table file %s: subject %s has %d values but the header lists %d structures'
                                % (tablefile, subject, len(cells), len(structlist)))

            values[subject] = {struct: float(val) for struct, val in zip(structlist, cells)}

    return (structlist, values, measure)

# load both tables; structure renaming is applied unless --noreplace53 was given
replace53 = not args.noreplace53
structlist1, values1, measure1 = read_table(args.t1, replace53)
structlist2, values2, measure2 = read_table(args.t2, replace53)

# the tables are only comparable when they report the same measure
if not measure1 == measure2:
    sf.system.fatal('table 1 measure (%s) does not match table 2 measure (%s)' % (measure1, measure2))

# check for subject and structure differences
def check_differences(a, b, dtype):
    """Return the entries of ``a`` that also occur in ``b``, in ``a``'s order.

    Every entry that appears in only one of the two lists is reported on
    stdout. If any such entry exists and the module-level ``args.common`` is
    not set, the mismatch is reported through ``sf.system.fatal``; the same
    happens when the lists share no entry at all.

    Args:
        a: entries of table 1 (hashable, e.g. structure names).
        b: entries of table 2.
        dtype: singular noun used in the messages (e.g. 'structure').

    Returns:
        The list of common entries.
    """
    # set lookups keep the membership tests linear in the table size
    in_b = set(b)
    common = [x for x in a if x in in_b]
    if not common:
        sf.system.fatal('tables have no common %ss!' % dtype)

    # report the unique entries of BOTH tables before deciding to abort, so
    # the user sees the complete list of mismatches in one run
    in_common = set(common)
    found_unique = False
    for tnum, lst in enumerate((a, b), start=1):
        for entry in lst:
            if entry not in in_common:
                found_unique = True
                print('info: table %d has unique %s %s' % (tnum, dtype, entry))

    if found_unique and not args.common:
        sf.system.fatal('tables have differing %ss, to diff only on common data, use the --common flag' % dtype)
    return common

# structures present in both tables
structures = check_differences(structlist1, structlist2, 'structure')

# subject names of each table, in the order they were read
subs1 = list(values1)
subs2 = list(values2)

# identical subject lists are required unless --diff-subjs was given
if subs1 != subs2 and not args.diff:
    sf.system.fatal('subjects are not the same, to run on different subjects pass the --diff-subjs flag')

# rows are paired by position, so both tables need the same number of subjects
if len(subs2) != len(subs1):
    sf.system.fatal('tables have differing number of subjects, must have same number of subjects in both tables')

# output row labels of the form "<table 1 subject>,<table 2 subject>"
subjects = [','.join((str(name1), str(name2))) for name1, name2 in zip(subs1, subs2)]

# build matrices from common data: one row per paired subject, one column
# per common structure
mat1 = np.zeros((len(subjects), len(structures)))
mat2 = np.zeros((len(subjects), len(structures)))
for i, (sub1, sub2) in enumerate(zip(subs1, subs2)):
    row1 = values1[sub1]
    row2 = values2[sub2]
    for j, structure in enumerate(structures):
        mat1[i, j] = row1[structure]
        mat2[i, j] = row2[structure]

# compute diff, scaled by the --mul and --div factors (numpy broadcasts a
# scalar or a length-1 factor over the whole matrix)
diff = (mat1 - mat2) * args.mul / args.div

# check if computing percent diff; the first requested mode wins
perc_den = None
if args.percent:
    perc_den = (mat1 + mat2) / 2.0
elif args.percent1:
    perc_den = mat1
elif args.percent2:
    perc_den = mat2

# compute percent diff
if perc_den is not None:
    # divide only where the denominator is nonzero and leave 0 elsewhere;
    # this avoids the divide-by-zero / invalid-value RuntimeWarnings that a
    # full division followed by patching the result would emit
    diff = np.divide(100 * diff, perc_den, out=np.zeros_like(diff), where=perc_den != 0)

# write the output table: a header line with the measure and structure names,
# then one line per subject pair with its diff values
with open(args.o, 'w') as outfile:
    column_names = '  '.join(structures)
    outfile.write('Measure:%s-diff  %s\n' % (measure1, column_names))
    outfile.writelines(
        '%s  %s\n' % (label, '  '.join('%f' % val for val in row))
        for label, row in zip(subjects, diff)
    )
