#!/usr/bin/env python3

# user will need dev (3/15/2018) version of mri_surfcluster and mri_volcluster

import os
import sys
import shlex
import argparse

from glob import glob
from math import log10
from subprocess import call


def errExit(message, ret=1):
  """Print `message` prefixed with 'error: ' and terminate the script.

  message: text to report (must be a string).
  ret:     process exit status, 1 by default.
  """
  report = ''.join(('error: ', message))
  print(report)
  raise SystemExit(ret)


# matlab can't handle .nii.gz - must convert to .nii
def convertNifty(src_vol, trg_vol):
  """Convert `src_vol` to `trg_vol` with mri_convert and return `trg_vol`.

  The command is run from an argument list rather than a shell string, so
  paths containing spaces or shell metacharacters are passed through intact.

  Exits the script with mri_convert's return code if the conversion fails,
  or with status 127 if mri_convert cannot be launched at all.
  """
  print('"nii.gz" format is unsupported - converting %s to %s' % (src_vol, trg_vol))
  try:
    ret = call(['mri_convert', src_vol, trg_vol])
  except OSError as err:
    # without a shell, a missing executable raises instead of returning 127
    print('error: could not run mri_convert: %s' % err)
    sys.exit(127)
  if ret != 0:
    sys.exit(ret)
  return trg_vol


# If a user does this: --pargs "-flag", then argparse will fail because it interprets "-flag"
# as a separate argument. This can be fixed by using --pargs="-flag", so here we override
# the ArgumentParser error function to notify users if this issue occurs.
class ArgumentParser(argparse.ArgumentParser):
  """argparse.ArgumentParser that adds a syntax hint when --pargs fails to parse."""

  def error(self, message):
    """Print the usage line and an error message to stderr, then exit with status 2.

    If the exception currently being handled names '--pargs' as the offending
    argument, `message` is replaced by a hint showing the --pargs="..." syntax.
    """
    exc = sys.exc_info()[1]
    # There may be no active exception, or one that is not an argparse
    # ArgumentError; neither has an `argument_name` attribute, so look it up
    # defensively instead of raising AttributeError from the error handler.
    if getattr(exc, 'argument_name', None) == '--pargs':
      message = 'additional arguments should be passed to the palm function with this syntax: --pargs="-flag1 -flag2"'
    self.print_usage(sys.stderr)
    args = {'prog': self.prog, 'message': message}
    self.exit(2, ('%(prog)s: error: %(message)s\n') % args)

#-------- main -------------------------
version = 'fspalm'

# ---- parse args ----
description = """\
Prepares and analyzes the output of mri_glmfit for Permutation
Analysis of Linear Models (PALM) to correct for multiple
comparisons. See https://surfer.nmr.mgh.harvard.edu/fswiki/FsPalm
for more info. PALM was written by Dr. Anderson Winkler, see
fsl.fmrib.ox.ac.uk/fsl/fslwiki/PALM"""

parser = ArgumentParser(
  description=description,
  formatter_class=argparse.RawTextHelpFormatter,
)
parser._action_groups.pop()  # list required arguments before optional

# required args: (flag, metavar, type, help)
required = parser.add_argument_group('required arguments')
for flag, placeholder, caster, blurb in (
  ('--glmdir', 'PATH', str, 'the mri_glmfit directory to prepare'),
  ('--cft', 'FLOAT', str, 'voxel-wise cluster forming threshold (CFT), -log10(p)'),
  ('--cwp', 'FLOAT', float, 'clusterwise p-value threshold'),
):
  required.add_argument(flag, metavar=placeholder, type=caster, help=blurb, required=True)

# tail - exactly one of these two flags must be given
mutually_exclusive = parser.add_argument_group('required flags (use one of the following)')
tails = mutually_exclusive.add_mutually_exclusive_group(required=True)
for flag, blurb in (
  ('--onetail', 'perform a one-tailed test'),
  ('--twotail', 'perform a two-tailed test. NOTE: changes CFT'),
):
  tails.add_argument(flag, action='store_true', help=blurb)

# optional args
optional = parser.add_argument_group('optional arguments')
optional.add_argument('--name', metavar='DIRNAME', type=str, default='palm',
                      help='name of palm subdirectory (default="palm")')
optional.add_argument('--iters', metavar='INT', default=10, help='number of iterations')
# simple on/off switches, kept in the order they appear in the help text
for flag, blurb in (
  ('--monly', 'only create matlab file, do not run'),
  ('--pponly', 'Only perform post-processing'),
  ('--octave', 'run with octave, not matlab'),
  ('--centroid', 'add --centroid flag to mri_surfcluster post-processing'),
  ('--2spaces', 'Bonferroni-correct for 2 spaces'),
  ('--3spaces', 'Bonferroni-correct for 3 spaces'),
):
  optional.add_argument(flag, action='store_true', help=blurb)
optional.add_argument('--pargs', metavar='SUBARGS', type=str,
                      help='supply additional args to be passed to the palm function')

# with no arguments at all, show the full help instead of a parse error
if len(sys.argv) == 1:
  parser.print_help()
  sys.exit(1)
argv0 = sys.argv  # kept so the full command line can be written to the log
args = parser.parse_args()

# retrieve any additional palm arguments; each one is wrapped in single
# quotes so it becomes a string argument in the generated palm() call
extra_opts = []
if args.pargs:
  extra_opts += ["'%s'" % x for x in shlex.split(args.pargs)]

# two-tailed test
if args.twotail:
  extra_opts.append("'-twotail'")

# check glmdir
glmdir = os.path.abspath(args.glmdir)
if not os.path.isdir(glmdir):
  errExit('"%s" is not a valid directory' % glmdir)
# make palm subdir
palmdir = os.path.join(glmdir, args.name)
print('preparing a palm subdirectory in %s' % palmdir)
# os.makedirs(palmdir, exist_ok=True)  # python3
if not os.path.exists(palmdir): os.makedirs(palmdir)  # settle for python2 for now

freesurferhome = os.environ.get('FREESURFER_HOME', '')

# start the log file with enough context to reproduce this invocation
LFfile = os.path.join(palmdir, 'fspalm.log')
LF = open(LFfile, 'w')
LF.write("fspalm log file %s\n" % version)
LF.write("FREESURFER_HOME %s\n" % freesurferhome)
# SUBJECTS_DIR is only validated for surface-based analysis further down,
# so an unset variable must not raise KeyError here
LF.write("setenv SUBJECTS_DIR %s\n" % os.environ.get('SUBJECTS_DIR', ''))
LF.write("cd %s\n" % os.getcwd())
LF.write(' '.join(argv0) + "\n\n")
# terminate the line so that later log entries do not run into it
LF.write('palm subdirectory %s\n' % palmdir)

# check whether this is a surface-based analysis (marked by a glmdir/surface file)
surface_analysis = False
if os.path.isfile(os.path.join(glmdir, 'surface')):
  surface_analysis = True
  print('found %s/surface file - this is a surface-based analysis' % os.path.basename(glmdir))
else:
  print('this is a volume-based analysis')

# ---- glm info ----
glm_log = os.path.join(glmdir, 'mri_glmfit.log')
if not os.path.isfile(glm_log):
  errExit('cannot find mri_glmfit.log in the glm directory')
with open(glm_log, 'r') as f:
  # blank lines are dropped so every remaining line has a first word
  loglines = [line for line in f.readlines() if line.strip()]

# extract input filename
input_file = None
for line in loglines:
  words = line.split()
  if words[0] == 'y':
    # a bare "y" line carries no filename; leave input_file unset so the
    # "could not extract" error below is reported instead of an IndexError
    if len(words) > 1:
      input_file = words[1]
  elif words[0] == 'cmdline' and '--fwhm' in line:
    # sanity check to make sure fwhm was not used
    print('ERROR: --fwhm was used in mri_glmfit. To run fspalm, explicitly smooth your data using mris_fwhm or mri_fwhm, rerun mri_glmfit without --fwhm, then run fspalm')
    sys.exit(1)
if not input_file:
  errExit("could not extract input filename from mri_glmfit.log")
if not os.path.isfile(input_file):
  errExit('mri_glmfit input file "{0}" is no longer available at this location'.format(input_file))
if input_file.endswith('.nii.gz'):
  # nii.gz input can't be used - must convert to nii.
  # Strip only the ".nii.gz" suffix so that dotted basenames (a.b.nii.gz)
  # keep their full name instead of being cut at the first dot.
  nii_name = os.path.basename(input_file)[:-len('.nii.gz')] + '.nii'
  input_file = convertNifty(input_file, os.path.join(palmdir, nii_name))
print('using the glmfit input file %s' % input_file)

# get output file format; fail now on anything else, otherwise oformat
# would be undefined when the palm output file names are built
if input_file.endswith('.nii'):
  oformat = '.nii'
elif input_file.endswith(('.mgh', '.mgz')):
  oformat = '.mgz'
else:
  errExit('unsupported input file format "%s" (expected .nii, .nii.gz, .mgh or .mgz)' % input_file)

# extract hemi (and subject info for surface-based analysis)
hemi = None
if not surface_analysis:
  # could be a cvs subject, need a way to figure this out
  # mri_glmfit-sim does not figure it out either, user spec on cmd line
  vol_targ_subject = 'fsaverage'
else:
  # the first line of glmdir/surface holds "<subject> <hemi>"
  with open(os.path.join(glmdir, 'surface'), 'r') as f:
    targ_subject, hemi = f.readline().split()
  subjects_dir = os.environ.get('SUBJECTS_DIR', '')
  if subjects_dir == '':
    errExit('you must set the SUBJECTS_DIR envionment variable to do surface-based analysis')
  # build the paths used below
  subject_dir = os.path.join(subjects_dir, targ_subject)
  surface_file = os.path.join(subject_dir, 'surf', hemi + '.white')
  # 'surffile' is the name of the variable holding this path in run_palm.m
  extra_opts.extend(["'-s'", "surffile"])
  # run a couple more sanity checks
  if not os.path.isdir(subject_dir):
    errExit('subject %s does not exist in current SUBJECTS_DIR (%s)' % (targ_subject, subjects_dir))
  if not os.path.isfile(surface_file):
    errExit('subject file %s does not exist' % surface_file)
  # use the average vertex area file when one sits next to the surface
  area_file = surface_file + '.avg.area.mgh'
  if not os.path.isfile(area_file):
    area_file = ""
  else:
    print('using area file %s' % area_file)
    extra_opts.append(" areafile")

# ---- design matrix ----
# read the Xg.dat
xg_path = os.path.join(glmdir, 'Xg.dat')
if not os.path.isfile(xg_path):
  errExit('cannot find Xg.dat in the glm directory')
with open(xg_path, 'r') as f:
  # skip blank lines (e.g. a trailing newline) so they are neither counted
  # as data points nor fed to float() below
  xg_content = [x.strip() for x in f if x.strip()]
if not xg_content:
  errExit('Xg.dat in the glm directory is empty')
num_points = len(xg_content)
num_waves = len(xg_content[0].split())

# every row must have the same number of columns; check this before anything
# is written so a bad Xg.dat never leaves a partial design.mat behind
for thing in xg_content:
  if len(thing.split()) != num_waves:
    errExit('number of waves (cols) in Xg.dat are not consistent')

# check if the design is a one-sample group mean (matrix is a column of all ones)
if num_waves == 1 and all(float(row) == 1.0 for row in xg_content):
  print('Design detected as a one-sample group mean, adding -ise to implment sign flipping')
  extra_opts.append("'-ise'")

# write the palm design.mat
with open(os.path.join(palmdir, 'design.mat'), "w") as f:
  f.write('/NumWaves {0}\n'.format(num_waves))
  f.write('/NumPoints {0}\n'.format(num_points))
  f.write('/Matrix\n')
  for thing in xg_content:
    f.write(thing + '\n')

# ---- contrasts ----
# read in the C.dat files (one per contrast subdirectory of the glm directory)
contrast_paths = glob('{0}/*/C.dat'.format(glmdir))
if len(contrast_paths) == 0:
  errExit('cannot find any contrast subdirs (with C.dat) in the glm directory')
contrasts = []
contrast_names = []
for contrast_file in contrast_paths:
  with open(contrast_file) as f:
    # ignore blank lines so a trailing newline does not make a single-row
    # contrast look like a multi-row one
    lines = [row for row in f.read().splitlines() if row.strip()]
  if len(lines) > 1:
    print('warning: excluding contrast file %s since it has multiple rows' % contrast_file)
  elif not lines:
    print('warning: excluding contrast file %s since it is empty' % contrast_file)
  else:
    contrasts.append(lines[0])
    # the contrast is named after the subdirectory holding its C.dat
    contrast_names.append(os.path.basename(os.path.dirname(contrast_file)))
if not contrasts:
  errExit('no usable single-row contrasts found in the glm directory')

# each contrast needs one weight per design-matrix column; check before
# writing so a mismatch never leaves a partial design.con behind
for contrast in contrasts:
  if len(contrast.split()) != num_waves:
    errExit('number of waves (cols) are not consistent')

# write the palm design.con
with open(os.path.join(palmdir, 'design.con'), "w") as f:
  f.write('/NumWaves {0}\n'.format(num_waves))
  f.write('/NumPoints {0}\n'.format(len(contrasts)))
  f.write('/Matrix\n')
  for contrast in contrasts:
    f.write(contrast + '\n')

# save a contrast-name lookup file (1-based index, then name)
with open(os.path.join(palmdir, 'contrast_names.txt'), "w") as f:
  for i, con_name in enumerate(contrast_names):
    f.write('%s  %s\n' % (i+1, con_name))

# ---- mask ----
# take the first mask.<ext> that exists in the glm directory, trying the
# extensions in this order
mask_candidates = [os.path.join(glmdir, 'mask' + ext) for ext in ('.mgz', '.mgh', '.nii', '.nii.gz')]
maskfile = next((path for path in mask_candidates if os.path.isfile(path)), '')
if maskfile == '':
  errExit('cannot find any mask volume in the glm directory')
if maskfile.endswith('.nii.gz'):
  # nii.gz mask can't be used - must convert to nii
  maskfile = convertNifty(maskfile, os.path.join(palmdir, 'mask.nii'))

# nifti files (including ones converted from nii.gz above) are rejected
# for surface-based analysis
if surface_analysis:
  if input_file.endswith('.nii'):
    print('ERROR: for surface based analysis, input cannot be nii or nii.gz')
    sys.exit(1)
  if maskfile.endswith('.nii'):
    print('ERROR: for surface based analysis, mask cannot be nii or nii.gz')
    sys.exit(1)

# cd to output dir
os.chdir(palmdir)

# create list of palm output files (should use same file format as inputs).
# With a single contrast the names carry no contrast index; otherwise each
# name gets a 1-based "_c<N>" suffix.
if len(contrast_names) == 1:
  con_suffixes = ['']
else:
  con_suffixes = ['_c%d' % (k + 1) for k in range(len(contrast_names))]

# surface runs use the "dpv" stat files, volume runs the "vox" ones
dpv_stem = 'fsp_dpv_tstat' if surface_analysis else 'fsp_vox_tstat'
fwep_file_list = ['fsp_clustere_tstat_fwep' + sfx + oformat for sfx in con_suffixes]
dpv_file_list = [dpv_stem + sfx + oformat for sfx in con_suffixes]


if not args.pponly:  # skipped entirely when only post-processing is requested
  # delete any previous palm outfiles
  for stale in fwep_file_list:
    if os.path.isfile(stale):
      os.remove(stale)

  # ---- matlab script ----
  # collect the lines of run_palm.m first, then write them in one go
  script_lines = [
    "hemi = '%s';\n" % hemi,
    "input = '%s';\n" % input_file,
    "maskfile = '%s';\n" % maskfile,
  ]
  if surface_analysis:
    script_lines.append("surffile = '%s';\n" % surface_file)
    if len(area_file) > 0:
      script_lines.append("areafile = '%s';\n" % area_file)
  script_lines.append("\n")
  # Increase the CFT for two-tailed. PALM generates one-tailed p-values.
  # Adding 0.301 to the sig threshold effectively halves the p-value
  # threshold, e.g. sig=2 means p=.01 and sig=2.301 means p=.005. AW argues
  # that this added complication is not really needed given that the CFT is
  # arbitrary; it is done anyway because this is what mri_glmfit-sim does.
  if args.twotail:
    script_lines.append("%% Changing threshold for two-tailed test\n")
    script_lines.append("cft = %s + 0.301;\n" % args.cft)
  else:
    script_lines.append("cft = %s;\n" % args.cft)
  script_lines += [
    "pthresh = 10^-cft;\n",
    "zthresh = fast_p2z(pthresh);\n",
    "iters = %s;\n" % args.iters,
    "\n",
    "zthreshstr = sprintf('%f',zthresh);\n",
    "itersstr = sprintf('%d',iters);\n",
    "\n",
    "palm('-i',input,'-m',maskfile,'-d','design.mat','-t','design.con','-logp',...\n",
    "     '-n',itersstr,'-C',zthreshstr,'-o','fsp'",
    # every extra option becomes one more comma-separated palm() argument
    "%s);\n" % ''.join(",%s" % x for x in extra_opts),
    "return;\n",
  ]
  script_path = os.path.join(palmdir, 'run_palm.m')
  with open(script_path, "w") as f:
    f.writelines(script_lines)

  if args.monly:
    print("'--monly' requested, so exiting without running...")
    print("to initiate palm processing, cd into %s and run 'run_palm.m' from matlab" % palmdir)
    sys.exit(0)

  # run the run_palm.m script, feeding it to the interpreter on stdin
  print('palm prep complete - running the generated matlab script...')
  runner = 'octave --no-gui < run_palm.m' if args.octave else 'matlab -display iconic < run_palm.m'
  try:
    ret = call(runner, shell=True)
  except KeyboardInterrupt:
    errExit('run_palm.m interrupted')

# check for successful palm output since we can't get a return code from matlab
for expected in fwep_file_list:
  if os.path.isfile(expected):
    continue
  errExit('cannot find expected palm output "%s" - check for run_palm.m failures' % expected)

# ---- post-processing ----
if args.cwp <= 0:
  # log10 is undefined here; report it instead of dying with a math domain error
  errExit('--cwp must be greater than 0 (got %s)' % args.cwp)
cwpsig = -log10(args.cwp)


def _run_logged(cmd, failure_msg, show_cwd=False):
  """Echo `cmd`, record it in the fspalm log, run it through the shell, exit on failure.

  cmd:         shell command line to execute.
  failure_msg: message passed to errExit when the command returns non-zero;
               the script then exits with the command's return code.
  show_cwd:    when true, the current directory is printed and logged on its
               own line ahead of the command.
  """
  if show_cwd:
    print(os.getcwd())
    LF.write(os.getcwd() + "\n")
  print(cmd)
  LF.write(cmd + "\n\n")
  ret = call(cmd, shell=True)
  if ret != 0:
    errExit(failure_msg, ret)


# Bonferroni options are the same for every contrast
bonferroni_opts = ''
if vars(args)['2spaces']: bonferroni_opts += ' --bonferroni-max 2'
if vars(args)['3spaces']: bonferroni_opts += ' --bonferroni-max 3'

if surface_analysis:  # =============================================
  print('palm processing complete - running the surface cluster analysis...')
  for con_name, fwep_file, dpv_file in zip(contrast_names, fwep_file_list, dpv_file_list):
    # mri_surfcluster command
    cmd =  'mri_surfcluster --in %s --thmin %f' % (fwep_file, cwpsig)
    cmd += ' --subject %s --hemi %s' % (targ_subject, hemi)
    cmd += ' --annot aparc --sum %s.clustertable.summary --no-fixmni --surf white' % con_name
    cmd += ' --ocn %s.ocn%s --oannot ./%s.ocn.annot --sig2p-max' % (con_name, oformat, con_name)
    # additional options
    if args.centroid:
      cmd += ' --centroid'
    cmd += bonferroni_opts
    _run_logged(cmd, 'mri_surfcluster failed', show_cwd=True)
    # Compute mean of input over clusters
    cmd =  'mri_segstats --seg %s.ocn%s --exclude 0' % (con_name, oformat)
    cmd += ' --i %s --avgwf %s.y.ocn.dat' % (input_file, con_name)
    _run_logged(cmd, 'mri_segstats failed', show_cwd=True)
    # This is a bit of a hack to get the vertex of the max vox. The mri_surfcluster above
    # will get the max vox from the cwp, which is the same at each vertex. The dpv is
    # the uncorrected voxel-wise p-value (which is what is used in mri_glmfit-sim).
    # I don't think that the order of the clusters is guaranteed to be the same.
    # Need a better fix for this
    cmd =  'mri_surfcluster --in %s --thmin 0.5' % (dpv_file)
    cmd += ' --subject %s --hemi %s' % (targ_subject, hemi)
    cmd += ' --annot aparc --sum %s.dpv.clustertable.summary --no-fixmni --surf white' % con_name
    cmd += ' --mask %s.ocn%s ' % (con_name, oformat)
    _run_logged(cmd, 'mri_surfcluster dpv failed, maybe 0 vertices found', show_cwd=True)

else:  # volume analysis  ===========================================
  print('palm processing complete - running the volume cluster analysis...')
  if not freesurferhome:
    # the registration file passed via --reg below lives under FREESURFER_HOME
    errExit('you must set the FREESURFER_HOME environment variable to do volume-based analysis')
  reg_opt = ' --reg %s/subjects/%s/mri.2mm/reg.2mm.dat' % (freesurferhome, vol_targ_subject)
  for con_name, fwep_file, dpv_file in zip(contrast_names, fwep_file_list, dpv_file_list):
    # mri_volcluster command (run once per contrast)
    cmd =  'mri_volcluster --in %s --thmin %f' % (fwep_file, cwpsig)
    cmd += ' --seg %s aparc+aseg.mgz' % (vol_targ_subject)
    cmd += ' --sum  %s.clustertable.summary --no-fixmni ' % (con_name)
    cmd += ' --ocn %s.ocn%s --allowdiag --sig2p-max' % (con_name, oformat)
    cmd += reg_opt
    # additional options
    # no centroid option in volume
    cmd += bonferroni_opts
    _run_logged(cmd, 'mri_volcluster failed')
    # Compute mean of input over clusters
    cmd =  'mri_segstats --seg %s.ocn%s --exclude 0' % (con_name, oformat)
    cmd += ' --i %s --avgwf %s.y.ocn.dat' % (input_file, con_name)
    _run_logged(cmd, 'mri_segstats failed')
    # Same hack as in the surface branch: cluster the dpv (uncorrected
    # voxel-wise p-value) map within the cluster mask to locate the max voxel
    cmd =  'mri_volcluster --in %s --thmin 0.5' % (dpv_file)
    cmd += ' --seg %s aparc+aseg.mgz' % (vol_targ_subject)
    cmd += ' --sum  %s.dpv.clustertable.summary --no-fixmni ' % (con_name)
    cmd += ' --mask %s.ocn%s --allowdiag ' % (con_name, oformat)
    cmd += reg_opt
    _run_logged(cmd, 'mri_volcluster dpv failed')

LF.close()
print('fspalm complete')

