#!/usr/bin/python3

import os
import time
import sys
import platform
import csv
import glob
import tempfile
import shutil
import platform
import argparse
import numpy as np
import surfa as sf
import scipy.ndimage


# defer tensorflow import until we need it (for faster command-line parsing)
# main() rebinds this module-level name (`global tf; import tensorflow as tf`)
# only after the command line has been parsed and the TF/CUDA environment
# variables have been set.
tf = None


# Help text for the command line; main() passes it as the argparse
# ArgumentParser description.
description = """
Segment subcortical limbic structures.

Input images can be provided by one of two methods. To segment one
or multiple T1-weighted images, use the --i flag to point to an
input image file or directory containing a series of images. The
--o flag should specify the corresponding output segmentation file
or directory. For example:

    mri_sclimbic_seg --i image.mgz --o seg.mgz

To process a series of freesurfer recon-all subjects, use the --s
input flag. When no arguments are provided to this flag, subjects
will be searched for in the 'subjects directory' defined by the
--sd flag or the SUBJECTS_DIR env variable. Otherwise, a set of
subject names can be specified as arguments. For example:

    mri_sclimbic_seg --s subj1 subj2 subj3

In freesurfer subject-mode, outputs will be saved to the subject's
mri and stats subdirectories, and volumetric stats will be computed
and saved automatically.
"""


# ------------------------------------------------------------------------------------------------
#                                         Main Entrypoint
# ------------------------------------------------------------------------------------------------


def main():
    """
    Command-line entry point for mri_sclimbic_seg.

    Parses the command line, configures the compute device and tensorflow, loads the
    color lookup table, population stats, and model weights, and then segments either
    a set of freesurfer recon-all subjects (--s, subject-mode) or one or more image
    files (--i/--o). Per-case outputs are written by LimbicSegmenter.process_files and
    summary csv files are written at the end. Invalid input aborts via sf.system.fatal.
    """

    # configure command-line
    parser = argparse.ArgumentParser(description=description)

    # normal-mode options
    parser.add_argument('-i', '--i', help='T1-w image(s) to segment. Can be a path to a single image or a directory of images.')
    parser.add_argument('-o', '--o', help='Segmentation output (required if --i is provided). Must be the same type as '
                                    'the input path (a single file or directory).')

    # subject-mode options
    parser.add_argument('-s', '--s', nargs='*', help='Process a series of freesurfer recon-all subjects (enables subject-mode).')
    parser.add_argument('--sd', help='Set the subjects directory (overrides the SUBJECTS_DIR env variable).')

    # general options
    parser.add_argument('--conform', action='store_true', help='Resample input to 1mm-iso; results will be put back in native resolution.')
    parser.add_argument('--etiv', action='store_true', help='Include eTIV in volume stats (enabled by default in subject-mode and --tal).')
    parser.add_argument('--tal', help='Alternative talairach xfm transform for estimating TIV. Can be file or suffix (for multiple inputs).')
    parser.add_argument('--write_posteriors', action='store_true', help='Save the label posteriors.')
    parser.add_argument('--write_volumes', action='store_true', help='Save label volume stats (enabled by default in subject-mode).')
    parser.add_argument('--write_qa_stats', action='store_true', help='Save QA stats (z and confidence).')
    parser.add_argument('--exclude', type=int, nargs='+', default=[], help='List of label IDs to exclude in any output stats files.')
    parser.add_argument('--keep_ac', action='store_true', help='Explicitly keep anterior commissure in the volume/qa files.')
    parser.add_argument('--vox-count-volumes', action='store_true', help='Use discrete voxel count for label volumes.')
    parser.add_argument('--model', help='Alternative model weights to load.')
    parser.add_argument('--ctab', help='Alternative color lookup table to embed in segmentation. Must be minimal, including 0, and sorted.')
    parser.add_argument('--population-stats', help='Alternative population volume stats for QA output.')
    parser.add_argument('--debug', action='store_true', help='Enable debug logging.')
    parser.add_argument('--vmp', action='store_true', help='Enable printing of vmpeak at the end.')
    parser.add_argument('--threads', type=int, default=1, help='Number of threads to use. Default is 1.')
    parser.add_argument('--features', type=int, default=24, help='Number of features (default is 24)')
    parser.add_argument('--7T', dest='sevenT', action='store_true', help='Preprocess 7T images (just sets percentile to 99.9).')
    parser.add_argument('--percentile', type=float, help='Use intensity percentile threshold for normalization.')
    parser.add_argument('--cuda-device', help='Cuda device for GPU support.')
    parser.add_argument('--output-base', default='sclimbic', help='String to use in output file name; default is sclimbic')
    parser.add_argument('--no-cite-sclimbic', dest='citeLimbic', action='store_false', help='Do not cite sclimbic paper at the end.')
    parser.add_argument('--logfile', help='Set logfile (default is mri_sclimbic.log)')
    parser.add_argument('--fov', type=int, default=160, help='Set FoV')
    # Ideally, we would get this from the model
    parser.add_argument('--nchannels', type=int, default=1, help='Number of channels')

    # check for no arguments
    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit(1)

    # print out the command line
    print(' '.join(sys.argv))

    # parse commandline
    args = parser.parse_args()

    # a few sanity checks on the command-line inputs
    if args.i is None and args.s is None:
        sf.system.fatal('Input image(s) or subject(s) to segment must be provided with the --i or --s flags.')
    if args.i is not None and args.s is not None:
        sf.system.fatal('Cannot provide both input image (--i) and subject (--s) flags. Choose one input mode.')
    if args.i is not None and args.o is None:
        sf.system.fatal('--o output flag must be provided if --i input is used.')

    # Automatically exclude AntCom (label 853) unless explicitly kept
    if not args.keep_ac and 853 not in args.exclude:
        args.exclude.append(853)
    if 853 not in args.exclude:
        print('Keeping anterior commissure in vols and stats')

    if len(args.exclude) > 0:
        print("Excluding seg", args.exclude)

    # eTIV is always estimated in subject-mode and whenever --tal is given, so make
    # sure it is also stored for the summary csv in those cases (as the --etiv help
    # text documents); previously subject-mode silently dropped it from the summary
    if args.tal is not None or args.s is not None:
        args.etiv = True

    # check for fs home
    if not os.environ.get('FREESURFER_HOME'):
        sf.system.fatal('FREESURFER_HOME is not set. Please source FreeSurfer.')

    # configure cuda device
    if args.cuda_device is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_device

    cuda_device = os.getenv('CUDA_VISIBLE_DEVICES')
    if cuda_device is None or cuda_device == '-1':
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
        print('Using CPU')
    else:
        print('Using GPU device', cuda_device)

    # defer tensorflow importing until after parsing
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if args.debug else '3'
    global tf
    import tensorflow as tf
    if not args.debug:
        tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    # set number of threads
    print('Using %d thread(s)' % args.threads)
    tf.config.threading.set_inter_op_parallelism_threads(args.threads)
    tf.config.threading.set_intra_op_parallelism_threads(args.threads)

    percentile = args.percentile
    if args.sevenT:
        print('7T image flag provided, using 99.9 percentile normalization')
        percentile = 99.9

    # load lookup table for segmentation
    lut_file = args.ctab if args.ctab else os.path.join(os.environ.get('FREESURFER_HOME'), 'models', 'sclimbic.ctab')
    labels = sf.load_label_lookup(lut_file)
    print('Loaded lookup table', lut_file)

    # load population stats for QA purposes (keyed by the 'VolStat_mm3' column)
    if args.population_stats:
        pop_stats_file = args.population_stats
    else:
        pop_stats_file = os.path.join(os.environ.get('FREESURFER_HOME'), 'models', 'sclimbic.volstats.csv')
    # newline='' is what the csv module expects for files it reads
    with open(pop_stats_file, 'r', newline='') as csvfile:
        population_stats = {row.pop('VolStat_mm3'): row for row in csv.DictReader(csvfile)}
    print('Loaded population stats', pop_stats_file)

    # load model weights and initialize segmenter
    model_file = args.model if args.model else os.path.join(os.environ.get('FREESURFER_HOME'), 'models', 'sclimbic.fsm+ad.t1.nstd00-50.nstd32-50.h5')
    segmenter = LimbicSegmenter(model_file=model_file,
                                labels=labels,
                                population_stats=population_stats,
                                conform=args.conform,
                                store_etiv=args.etiv,
                                store_qa_stats=args.write_qa_stats,
                                debug=args.debug,
                                volumes_from_vox_count=args.vox_count_volumes,
                                percentile=percentile,
                                exclude=args.exclude,
                                inshape=(args.fov, args.fov, args.fov),
                                nfeatures=args.features,
                                nchannels=args.nchannels)
    print('Loaded model weights', model_file)

    # loop through each image and segment
    if args.s is not None:

        # in freesurfer subject-mode, we can grab the subject's nu.mgz
        # and use it as input
        sd = os.getenv('SUBJECTS_DIR') if args.sd is None else args.sd
        if sd is None:
            sf.system.fatal('Must set subjects directory with --sd or SUBJECTS_DIR env variable.')
        summary_file_prefix = os.path.join(sd, args.output_base + '_')
        print('Using subject directory', sd)

        # if no subjects have been provided, let's search the subjects directory
        # (sorted so that processing and summary row order are deterministic)
        subjects = args.s
        if len(subjects) == 0:
            nus = sorted(glob.glob(f'{sd}/*/mri/nu.mgz'))
            subjects = [os.path.basename(n.replace('/mri/nu.mgz', '')) for n in nus]

        # we still haven't found anything
        if len(subjects) == 0:
            sf.system.fatal(f'Subjects directory {sd} does not contain any valid recon-all subjects.')

        # loop through subjects, set filenames, and segment
        for n, subj in enumerate(subjects):

            # sanity check the subject
            subjdir = os.path.join(sd, subj)
            if not os.path.isdir(subjdir):
                sf.system.fatal(f'Recon-all subject {subj} does not exist in {sd}.')

            # set default IO parameters
            params = {
                'input_file': os.path.join(subjdir, 'mri', 'nu.mgz'),
                'segmentation_path': os.path.join(subjdir, 'mri', args.output_base + '.mgz'),
                'volumes_path': os.path.join(subjdir, 'stats', args.output_base + '.stats'),
                'case_name': subj,
            }

            # estimate TIV from talairach lta (--tal is relative to the subject dir)
            if args.tal:
                lta = os.path.join(subjdir, args.tal)
            else:
                lta = os.path.join(subjdir, 'mri', 'transforms', 'talairach.xfm.lta')
            params['etiv'] = compute_etiv_from_lta(lta)
            print('Computed eTIV from talairach')

            # save posterior data to the mri subdir
            if args.write_posteriors:
                params['posteriors_path'] = os.path.join(subjdir, 'mri', args.output_base + '.posteriors.mgz')

            # save QA output to the stats subdir
            if args.write_qa_stats:
                params['qa_stats_path'] = os.path.join(subjdir, 'stats', args.output_base + '.qa.stats')

            print('\nSegmenting subject %s %d/%d' % (subj, n + 1, len(subjects)))
            segmenter.process_files(**params)

        if len(subjects) > 1:
            segmenter.write_all_case_volumes(summary_file_prefix + 'volumes_all.csv')

    else:
        # normal file/directory input mode

        # available file formats
        exts = ('.mgh', '.mgz', '.nii', '.nii.gz')

        isdir = os.path.isdir(args.i)
        if isdir:
            # find valid images in input directory
            input_files = []
            for ext in exts:
                input_files += sorted(glob.glob(os.path.join(args.i, '*' + ext)))
            if len(input_files) == 0:
                sf.system.fatal(f'Could not find any valid input images in {args.i}.')
            os.makedirs(args.o, exist_ok=True)
            if args.logfile is None:
                logfile = os.path.join(args.o, 'mri_sclimbic.log')
            else:
                logfile = os.path.join(args.o, args.logfile)
        else:
            if not args.i.endswith(exts):
                sf.system.fatal(f'{args.i} is an unsupported image file type.')
            input_files = [args.i]
            dirname = os.path.dirname(args.o)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            if args.logfile is None:
                logfile = os.path.join(dirname, 'mri_sclimbic.log')
            else:
                logfile = os.path.join(dirname, args.logfile)

        # Write the command line. Prob not the best way to do it.
        with open(logfile, 'w') as file:
            file.write('modelfile ' + model_file + '\n')
            file.write('ctab ' + lut_file + '\n')
            file.write('cd ' + os.getcwd() + '\n')
            file.write(' '.join(sys.argv) + '\n')

        # quick utility to add to filenames while keeping extension
        def split_extension(filename):
            for ext in exts:
                if filename.endswith(ext):
                    return (filename[:-len(ext)], ext)
            sf.system.fatal(f'{filename} is an unsupported image file type.')

        # loop through the images and segment
        for n, input_file in enumerate(input_files):
            params = {'input_file': input_file}

            # some logic to determine output filename
            if isdir:
                true_basename, ext = split_extension(os.path.basename(input_file))
                params['case_name'] = true_basename
                basename = os.path.join(args.o, true_basename + '.' + args.output_base)
                params['segmentation_path'] = f'{basename}{ext}'
            else:
                basename, ext = split_extension(args.o)
                params['case_name'] = basename
                params['segmentation_path'] = args.o

            # optional outputs
            if args.write_posteriors:
                params['posteriors_path'] = f'{basename}.posteriors{ext}'
            if args.write_volumes:
                params['volumes_path'] = f'{basename}.stats'
            if args.write_qa_stats:
                params['qa_stats_path'] = f'{basename}.qa.stats'

            # optional TIV estimation: needed by the per-case volume stats and, in
            # directory mode, by the eTIV column of the summary csv (which would
            # otherwise be asked to format a missing eTIV)
            if args.etiv and (args.write_volumes or isdir):
                if args.tal:
                    if isdir:
                        xfm = os.path.join(args.i, true_basename + args.tal)
                    else:
                        xfm = args.tal
                    params['etiv'] = compute_etiv_from_lta(xfm)
                    print('Computed eTIV from talairach file')
                else:
                    print("Computing etiv from scratch")
                    params['etiv'] = compute_etiv_from_scratch(input_file)

            print('\nSegmenting image %d/%d' % (n + 1, len(input_files)))
            segmenter.process_files(**params)

        if isdir:
            summary_file_prefix = os.path.join(args.o, args.output_base + '_')
        else:
            summary_file_prefix = f'{basename}.'

        # write all case volumes in output directory
        if os.path.isdir(args.o) and len(input_files) > 0:
            segmenter.write_all_case_volumes(summary_file_prefix + 'volumes_all.csv')

    # write qa stats
    if args.write_qa_stats:
        segmenter.write_all_case_zscores(summary_file_prefix + 'zqa_scores_all.csv')
        segmenter.write_all_case_confidences(summary_file_prefix + 'confidences_all.csv')

    # all done
    if args.citeLimbic:
        print('\nIf you use this tool in a publication, please cite:')
        print('A Deep Learning Toolbox for Automatic Segmentation of Subcortical Limbic Structures from MRI Images')
        print('Greve, DN, Billot, B, Cordero, D, Hoopes, M. Hoffmann, A, Dalca, A, Fischl, B,  Iglesias, JE, Augustinack, JC')
        print('2021, Neuroimage. 10.1016/j.neuroimage.2021.118610. PMID: 34571161. https://pubmed.ncbi.nlm.nih.gov/34571161.')

    # check memory usage
    if args.debug or args.vmp:
        print_vm_peak()

    print('done')

# ------------------------------------------------------------------------------------------------
#                                         LimbicSegmenter
# ------------------------------------------------------------------------------------------------


class LimbicSegmenter:
    """
    Isolated class to handle image IO, preprocessing, prediction, and postprocessing.

    Per-case label volumes, eTIVs, and mean label posteriors are accumulated across
    calls to process_files() so that summary csv files can be written afterwards
    with the write_all_case_* methods.
    """

    def __init__(self,
                 model_file,
                 labels,
                 population_stats,
                 conform=True,
                 store_etiv=False,
                 store_qa_stats=False,
                 debug=False,
                 volumes_from_vox_count=False,
                 percentile=None,
                 exclude=None,
                 inshape=(160, 160, 160),
                 nfeatures=24,
                 nchannels=1):
        """
        Build the unet and load its weights.

        Parameters:
            model_file: Path to the model weights (loaded by layer name).
            labels: Label lookup mapping label id -> label object with a `.name`. Its
                iteration order defines the model output channels.
            population_stats: Dict mapping label name -> row with 'mean' and 'std'
                entries, used for the z-score QA summary.
            conform: Reslice non-1mm inputs to 1mm-iso instead of failing.
            store_etiv: Store per-case eTIV for the summary volume csv.
            store_qa_stats: Compute and store per-label mean posterior probabilities.
            debug: Print timing information.
            volumes_from_vox_count: Compute volumes from hard-label voxel counts instead
                of summed posteriors.
            percentile: If set, normalize intensities using this percentile of the
                nonzero voxels as the maximum.
            exclude: Iterable of label ids to omit from stats outputs. Label 0 is always
                omitted. Defaults to no extra exclusions.
            inshape: Spatial shape of the network input.
            nfeatures: Base number of features passed to the unet (nb_features).
            nchannels: Number of input channels (image frames) the model expects.
        """
        self.labels = labels
        self.population_stats = population_stats
        self.conform = conform
        self.inshape = inshape
        self.case_volumes = {}
        self.case_etivs = {}
        self.case_prob_means = {}
        self.store_etiv = store_etiv
        self.store_qa_stats = store_qa_stats
        self.last_time = time.time()
        self.debug = debug
        self.volumes_from_vox_count = volumes_from_vox_count
        self.percentile = percentile
        self.nchannels = nchannels
        self.nfeatures = nfeatures

        # build mask of labels to exclude in stats output files
        # always ignore unknown label; `exclude` defaults to None (not a shared list)
        self.exclude = [0] + list(exclude or [])
        # explicit boolean array so it is always treated as a mask by numpy indexing
        self.exclude_mask = np.array([sid not in self.exclude for sid in labels.keys()], dtype=bool)
        self.label_names = [label.name for keep, label in zip(self.exclude_mask, self.labels.values()) if keep]
        print(f'nb_labels {len(self.labels)}')
        # build and load model
        print(f'inshape {self.inshape} features {nfeatures}')
        self.model = unet(nb_features=nfeatures,
                          input_shape=(*self.inshape, self.nchannels),
                          nb_levels=3,
                          conv_size=3,
                          nb_labels=len(self.labels),
                          name='unet',
                          prefix=None,
                          feat_mult=2,
                          pool_size=2,
                          padding='same',
                          dilation_rate_mult=1,
                          activation='elu',
                          use_residuals=False,
                          final_pred_activation='softmax',
                          nb_conv_per_level=2,
                          layer_nb_feats=None,
                          conv_dropout=0,
                          batch_norm=-1,
                          input_model=None)
        print(f'  {self.model.layers[0].input_shape[0]}')
        print(f"Loading weights from {model_file} -----------------------")
        self.model.load_weights(model_file, by_name=True)
        print("Done loading weights -----------------------")

    def reset_timer(self):
        """
        Reset internal timer.
        """
        self.last_time = time.time()

    def print_time(self, message):
        """
        Print the time elapsed since the last reset_timer() if debugging is enabled.
        """
        if self.debug:
            print('%s: %.4f s' % (message, time.time() - self.last_time))

    def write_all_case_volumes(self, path):
        """
        Write all case volumes (excluded labels removed) to a csv.

        When eTIVs are stored, an eTIV column is appended. A case without an eTIV
        is written as 'nan' rather than aborting the whole summary.
        """
        header = ['case'] + self.label_names
        if self.store_etiv:
            header.append('eTIV')
        with open(path, 'w') as file:
            file.write(','.join(header) + '\n')
            for case, volumes in self.case_volumes.items():
                volumes = volumes[self.exclude_mask]
                if self.store_etiv:
                    # '%.4f' % None would raise; write nan for a missing eTIV
                    etiv = self.case_etivs.get(case)
                    volumes = np.append(volumes, np.nan if etiv is None else etiv)
                file.write(','.join([case] + ['%.4f' % v for v in volumes]))
                file.write('\n')
        print('\nWrote summary of label volumes to', path)

    def write_all_case_zscores(self, path):
        """
        Write all case z-scores to a csv.

        Only labels present in the population stats are included. Each z-score is
        (volume - population mean) / population std.
        """
        stats = [self.population_stats.get(label) for label in self.label_names]
        stat_mask = [idx for idx, stat in enumerate(stats) if stat is not None]
        labels = [label for label, stat in zip(self.label_names, stats) if stat is not None]
        mean = [float(stat['mean']) for stat in stats if stat is not None]
        std = [float(stat['std']) for stat in stats if stat is not None]
        with open(path, 'w') as file:
            file.write(','.join(['case'] + labels) + '\n')
            for case, volumes in self.case_volumes.items():
                vol = volumes[self.exclude_mask][stat_mask]
                zscores = (vol - mean) / std
                file.write(','.join([case] + ['%.4f' % z for z in zscores]) + '\n')
        print('Wrote summary of label z-scores to', path)

    def write_all_case_confidences(self, path):
        """
        Write all case confidences (mean prediction prob) to a csv.
        """
        with open(path, 'w') as file:
            file.write(','.join(['case'] + self.label_names) + '\n')
            for case, prob_means in self.case_prob_means.items():
                file.write(','.join([case] + ['%.4f' % v for v in prob_means[self.exclude_mask]]) + '\n')
        print('Wrote summary of label prediction confidences to', path)

    def preprocess(self, image):
        """
        Preprocess an image by conforming it to the correct orientation, shape, and scale.

        Intensities are min/max (or min/percentile) normalized to [0, 1] before the
        image is conformed to the network input shape in RAS 1mm space.
        """
        # check resolution
        if not np.allclose(image.geom.voxsize, (1, 1, 1), rtol=0, atol=1e-2):
            image_geom_voxsize = [f'{x:4.2f}' for x in image.geom.voxsize]
            if self.conform:
                print(f'The input image has resolution {image_geom_voxsize} mm, but 1mm-isotropic input is required.\n'
                      'However, --conform has been specified, so the volume will be resliced to 1mm iso.\n')
            else:
                print('')
                sf.system.fatal(f'The input image has resolution {image_geom_voxsize}, but 1mm-isotropic input is required.\n'
                                'The volume can be resliced to 1mm-iso by specifying --conform (results may suffer).\n')

        # check channels
        if image.nframes != self.nchannels:
            sf.system.fatal(f'Input image has {image.nframes} frames, expecting {self.nchannels}.')

        # normalize image data
        dmin = image.min()
        dmax = image.max() if self.percentile is None else image.percentile(self.percentile, nonzero=True)
        if dmin == dmax:
            sf.system.fatal('Input image is blank!')
        image = (image.astype(np.float32) - dmin) / (dmax - dmin)
        image = image.clip(0, 1)

        # conform to RAS 1mm space
        processed = image.conform(shape=(*self.inshape, self.nchannels), voxsize=1.0, orientation='RAS', dtype='float32', copy=False)
        return processed

    def segment(self, image):
        """
        Segment a raw input image.

        Returns a tuple (posteriors, segmentation, vox_counts, volumes, mean_probs).
        Per-label arrays are ordered like self.labels. mean_probs is empty unless
        QA stats are stored.
        """
        self.reset_timer()
        conformed = self.preprocess(image)
        self.print_time('Preprocess time')

        # posterior prediction
        self.reset_timer()
        prediction = self.model.predict(conformed.framed_data[np.newaxis]).squeeze()
        self.print_time('Prediction time')

        self.reset_timer()

        # let's clean up the posteriors a bit, but we'll do this in
        # a minimal cropped space to speed things up
        seg = conformed.new(prediction.argmax(-1))
        bbox = seg.bbox(margin=2)
        cropped_seg = seg[bbox]
        posteriors = conformed.new(prediction)[bbox]

        # mask the posteriors around each label
        dilate_struct = build_binary_structure(1, 3)
        for label in range(1, len(self.labels)):
            cropped_pred_label = posteriors.data[..., label]
            label_mask = scipy.ndimage.binary_dilation(cropped_seg.data == label, dilate_struct)
            cropped_pred_label[np.logical_not(label_mask)] = 0
            posteriors[..., label] = cropped_pred_label

        # ensure that the posteriors sum to 1 in the cropped space
        if len(self.labels) > 2:
            posteriors[..., 0] = 1.0 - np.sum(posteriors[..., 1:], axis=-1)
        else:  # There is only one non-unknown label
            posteriors[..., 0] = 1.0 - posteriors[..., 1]
        posteriors = posteriors.clip(0, 1)
        posteriors /= np.sum(posteriors, axis=-1, keepdims=True)

        # resample cropped posteriors to original resolution
        posteriors = posteriors.resize(image.geom.voxsize, copy=False)

        # compute the final hard segmentation and compute voxel counts while we're at it
        vox_counts = []
        mean_probs = []
        argmax = posteriors.data.argmax(axis=-1)
        segmap = np.zeros(argmax.shape, dtype='int32')
        for n, nid in enumerate(self.labels.keys()):
            label_mask = argmax == n
            segmap[label_mask] = nid
            vox_counts.append(np.count_nonzero(label_mask))
            if self.store_qa_stats:
                probs = posteriors.data[..., n][label_mask]
                mean_probs.append(probs.mean() if len(probs) > 0 else 0.0)
        vox_counts = np.array(vox_counts)
        mean_probs = np.array(mean_probs)

        # compute label volumes in original resolution
        voxvol = np.prod(posteriors.geom.voxsize)
        if self.volumes_from_vox_count:
            volumes = voxvol * np.array(vox_counts)
        else:
            volumes = voxvol * posteriors.data.reshape(-1, posteriors.shape[-1]).sum(0)

        # resample final hard segmentation to original space
        segmentation = posteriors.new(segmap).resample_like(image, method='nearest')
        segmentation.labels = self.labels

        self.print_time('Postprocess time')

        return (posteriors, segmentation, vox_counts, volumes, mean_probs)

    def process_files(
        self,
        input_file,
        segmentation_path,
        posteriors_path=None,
        volumes_path=None,
        qa_stats_path=None,
        etiv=None,
        case_name=None):
        """
        Load, segment, and write outputs for one input image.

        Parameters:
            input_file: Path of the image to segment.
            segmentation_path: Output path for the hard segmentation.
            posteriors_path: Optional output path for the label posteriors.
            volumes_path: Optional output path for FS-style volume stats.
            qa_stats_path: Accepted for interface compatibility; this method does not
                write it (QA summaries are written by write_all_case_* methods).
            etiv: Optional eTIV to include in the volume stats and summary.
            case_name: Key used for this case in the summary files. Defaults to
                input_file when not provided.
        """
        # a None key would later break the summary csv writers
        if case_name is None:
            case_name = input_file

        # load image
        if not os.path.isfile(input_file):
            sf.system.fatal(f'Input image {input_file} does not exist')
        image = sf.load_volume(input_file)
        print('Loaded input image from', input_file)

        # segment
        post, seg, vox_counts, volumes, mean_probs = self.segment(image)

        # write segmentation
        seg.save(segmentation_path)
        print('Wrote segmentation to', segmentation_path)

        # write posteriors
        if posteriors_path is not None:
            post.save(posteriors_path)
            print('Wrote posteriors to', posteriors_path)

        # write volume stats in FS format
        if volumes_path is not None:
            with open(volumes_path, 'w') as file:
                file.write('# Subcortical Limbic Volumetric Stats\n')
                file.write('# Created by mri_sclimbic_seg\n')
                if etiv is not None:
                    file.write('# Measure EstimatedTotalIntraCranialVol, eTIV, Estimated ' + \
                              f'Total Intracranial Volume, {etiv:.6f}, mm^3\n')
                # (output channel index, label id) for every non-excluded label
                label_matches = [(vid, nid) for (vid, nid) in enumerate(self.labels.keys()) if nid not in self.exclude]
                file.write(f'# NRows {len(label_matches)}\n')
                file.write('# NTableCols 5\n')
                file.write('# ColHeaders Index SegId NVoxels Volume_mm3 StructName\n')
                for n, (vid, nid) in enumerate(label_matches):
                    file.write(f'{n+1: <4} {nid: >6}{vox_counts[vid]: >6}{volumes[vid]: >12.4f}    {self.labels[nid].name}\n')

            print('Wrote volume stats to', volumes_path)

        # store label volumes
        self.case_volumes[case_name] = volumes
        if self.store_etiv:
            self.case_etivs[case_name] = etiv

        # write mean probs
        if self.store_qa_stats:
            self.case_prob_means[case_name] = mean_probs


# ------------------------------------------------------------------------------------------------
#                                             Utilities
# ------------------------------------------------------------------------------------------------


def compute_etiv_from_lta(lta):
    """
    Estimate total intracranial volume (eTIV) from a talairach transform.

    Parameters:
        lta: Path to a transform file loadable by sf.load_affine (for example a
            subject's mri/transforms/talairach.xfm.lta).

    Returns:
        1e3 * 1948.106 divided by the determinant of the loaded affine.
    """
    # NOTE(review): 1948.106 is an empirical scale factor; its origin is not shown here
    determinant = sf.load_affine(lta).det()
    return 1e3 * 1948.106 / determinant


def compute_etiv_from_scratch(image):
    """
    Compute eTIV by conforming, normalizing, and registering the input image to
    talairach space. This will slow down processing substantially.

    Runs mri_convert, mri_nu_correct.mni, talairach_avi, and lta_convert on
    intermediates stored in a temporary directory. That directory is removed on
    return or when an exception (including SystemExit) propagates.

    Parameters:
        image: Path to the input image file.

    Returns:
        The eTIV computed by compute_etiv_from_lta() on the resulting transform.
    """
    import shlex

    # the context manager cleans up the intermediates even if a step fails
    # (NOTE(review): relies on sf.system.fatal exiting via SystemExit rather than os._exit)
    with tempfile.TemporaryDirectory() as tmpdir:
        norm = os.path.join(tmpdir, 'nu.mgz')
        xfm = os.path.join(tmpdir, 'talairach.xfm')
        lta = os.path.join(tmpdir, 'talairach.xfm.lta')

        # quote every path interpolated into a command line so that spaces or
        # shell metacharacters in file names cannot break (or alter) the command
        q_image = shlex.quote(str(image))
        q_norm = shlex.quote(norm)
        q_xfm = shlex.quote(xfm)
        q_lta = shlex.quote(lta)

        # conform the input image
        ret = sf.system.run(f'mri_convert --conform {q_image} {q_norm}')
        if ret != 0:
            sf.system.fatal('mri_convert --conform failed!')

        # run intensity normalization
        ret = sf.system.run(f'mri_nu_correct.mni --no-rescale --i {q_norm} --o {q_norm} --n 1 --proto-iters 1000 --distance 50 --ants-n4')
        if ret != 0:
            sf.system.fatal('mri_nu_correct failed!')

        # run talairach registration
        ret = sf.system.run(f'talairach_avi --i {q_norm} --xfm {q_xfm}')
        if ret != 0:
            sf.system.fatal('talairach_avi failed!')

        # convert XFM to LTA
        mni305 = os.path.join(os.environ.get('FREESURFER_HOME'), 'average', 'mni305.cor.mgz')
        q_mni305 = shlex.quote(mni305)
        ret = sf.system.run(f'lta_convert --src {q_norm} --trg {q_mni305} --inxfm {q_xfm} --outlta {q_lta} --subject fsaverage --ltavox2vox')
        if ret != 0:
            sf.system.fatal('lta_convert failed!')

        # estimate TIV (must happen before the temporary directory is removed)
        etiv = compute_etiv_from_lta(lta)

    return etiv


def print_vm_peak():
    """
    Print the VM peak (the `VmPeak` entry) of the running process.

    The value is read from /proc/<pid>/status and printed as a bare integer,
    in whatever unit that file reports it. This is only available on linux
    platforms; elsewhere the function does nothing.
    """
    if platform.system() != 'Linux':
        return
    procstat = os.path.join('/proc', str(os.getpid()), 'status')
    # context manager guarantees the file handle is closed (it previously leaked)
    with open(procstat, 'r') as fp:
        for line in fp:
            strs = line.split()
            # expected shape of the entry: "VmPeak: <value> <unit>"
            if len(strs) >= 3 and strs[0] == 'VmPeak:':
                print('vmpcma:', int(strs[1]))
                break


def build_binary_structure(connectivity, n_dims):
    """
    Return a dilation element with provided connectivity.

    The result is an integer array of shape (2 * connectivity + 1,) * n_dims that
    holds 1 wherever the Euclidean distance to the central element is at most
    `connectivity`, and 0 elsewhere.
    """
    width = 2 * connectivity + 1
    grid = np.ones([width] * n_dims)
    # zero only the center so the distance transform measures distance from it
    grid[(connectivity,) * n_dims] = 0
    distance = scipy.ndimage.distance_transform_edt(grid)
    return (distance <= connectivity) * 1


# ------------------------------------------------------------------------------------------------
#                         Neurite Utilities - See github.com/adalca/neurite
# ------------------------------------------------------------------------------------------------


def unet(nb_features,
         input_shape,
         nb_levels,
         conv_size,
         nb_labels,
         name='unet',
         prefix=None,
         feat_mult=1,
         pool_size=2,
         padding='same',
         dilation_rate_mult=1,
         activation='elu',
         use_residuals=False,
         final_pred_activation='softmax',
         nb_conv_per_level=1,
         layer_nb_feats=None,
         conv_dropout=0,
         batch_norm=None,
         input_model=None):
    """
    Unet-style tf.keras model with an overdose of parametrization. Copied with permission
    from github.com/adalca/neurite.

    Builds an encoder with conv_enc and a decoder with conv_dec that is wired to the
    encoder through skip connections, and returns the decoder model (whose input is
    the encoder input).

    Parameters:
        nb_features: Number of features in the first level.
        input_shape: Input shape (spatial dims followed by channels).
        nb_levels: Number of encoder levels.
        conv_size: Convolution kernel size.
        nb_labels: Number of output channels of the final prediction.
        name: Model name. Also the default layer-name prefix.
        prefix: Layer-name prefix. Defaults to `name`.
        pool_size: Pooling/upsampling factor. An int is expanded to one entry per
            spatial dimension.
        layer_nb_feats: Optional explicit feature count per conv layer. The first
            nb_levels * nb_conv_per_level entries go to the encoder, the rest to the decoder.
        Remaining arguments are forwarded unchanged to conv_enc and conv_dec.
    """
    if prefix is None:
        prefix = name

    # one pool-size entry per spatial dimension (last input_shape entry is channels)
    n_spatial = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * n_spatial

    # arguments that the encoder and decoder arms receive identically
    shared = dict(name=name,
                  prefix=prefix,
                  feat_mult=feat_mult,
                  pool_size=pool_size,
                  padding=padding,
                  dilation_rate_mult=dilation_rate_mult,
                  activation=activation,
                  use_residuals=use_residuals,
                  nb_conv_per_level=nb_conv_per_level,
                  conv_dropout=conv_dropout,
                  batch_norm=batch_norm)

    encoder = conv_enc(nb_features,
                       input_shape,
                       nb_levels,
                       conv_size,
                       layer_nb_feats=layer_nb_feats,
                       input_model=input_model,
                       **shared)

    # the decoder consumes whatever per-layer feature counts the encoder did not use
    if layer_nb_feats is None:
        decoder_feats = None
    else:
        decoder_feats = layer_nb_feats[(nb_levels * nb_conv_per_level):]

    # skip connections back into the encoder are what make this a u-net
    return conv_dec(nb_features,
                    None,
                    nb_levels,
                    conv_size,
                    nb_labels,
                    use_skip_connections=True,
                    final_pred_activation=final_pred_activation,
                    layer_nb_feats=decoder_feats,
                    input_model=encoder,
                    **shared)


def conv_enc(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             dilation_rate_mult=1,
             padding='same',
             activation='elu',
             layer_nb_feats=None,
             use_residuals=False,
             nb_conv_per_level=2,
             conv_dropout=0,
             batch_norm=None,
             input_model=None):
    """
    Fully Convolutional Encoder. Copied with permission from github.com/adalca/neurite.

    Builds the down arm of a U-Net. Each of nb_levels levels has nb_conv_per_level
    convolutions. These are followed by an optional residual merge and optional
    batch normalization, and then max pooling after every level except the last.

    Parameters:
        nb_features: Base feature count. Level i uses round(nb_features * feat_mult ** i)
            features unless layer_nb_feats is given.
        input_shape: Input shape (spatial dims followed by channels). When input_model
            is given, its first output is reshaped to this shape.
        nb_levels: Number of levels.
        conv_size: Convolution kernel size.
        name: Name of the returned model.
        prefix: Prefix for all layer names. Defaults to `name`.
        feat_mult: Per-level feature multiplier.
        pool_size: Max-pooling size, an int or a per-dimension tuple.
        dilation_rate_mult: Level i uses dilation rate dilation_rate_mult ** i.
        padding: Convolution/pooling padding mode.
        activation: Activation of the convolutions.
        layer_nb_feats: Optional explicit feature count for every conv layer, in order.
        use_residuals: Add a residual connection around each level's conv block.
        nb_conv_per_level: Number of convolutions per level.
        conv_dropout: Dropout rate applied along the feature axis after convolutions.
        batch_norm: If not None, the axis passed to BatchNormalization after each level.
        input_model: Optional model to chain onto. Its first output becomes the encoder input.

    Returns:
        A tf.keras.Model mapping the input to the final-level feature tensor. Conv layers
        are named '<prefix>_conv_downarm_<level>_<conv>'.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # first layer: input
    name = '%s_input' % prefix
    if input_model is None:
        input_tensor = tf.keras.layers.Input(shape=input_shape, name=name)
        last_tensor = input_tensor
    else:
        # chain onto an existing model: reshape its first output to input_shape
        input_tensor = input_model.inputs
        last_tensor = input_model.outputs
        if isinstance(last_tensor, list):
            last_tensor = last_tensor[0]
        last_tensor = tf.keras.layers.Reshape(input_shape)(last_tensor)
        input_shape = last_tensor.shape.as_list()[1:]

    # volume size data (the last entry of input_shape is the channel count)
    ndims = len(input_shape) - 1
    input_shape = tuple(input_shape)
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # prepare layers
    convL = getattr(tf.keras.layers, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
    maxpool = getattr(tf.keras.layers, 'MaxPooling%dD' % ndims)
    # fused batch norm is disabled on Apple-silicon macOS; None keeps the Keras default elsewhere
    fused_batch_norm = False if (platform.system() == 'Darwin' and platform.machine() == 'arm64') else None

    # down arm:
    # nb_levels levels of nb_conv_per_level convolutions; max pool after all but the last level
    lfidx = 0  # index into layer_nb_feats
    for level in range(nb_levels):
        lvl_first_tensor = last_tensor
        nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** level

        for conv in range(nb_conv_per_level):  # several convs per level, max pooling applied at the end
            if layer_nb_feats is not None:  # explicit per-layer feature counts override the level default
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            # conv_dec looks up skip-connection sources by this exact layer name
            name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:
                # last conv of a residual level has no activation; it is applied after the merge
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only (dropout layers keep automatic names)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)

        if use_residuals:
            convarm_layer = last_tensor

            # the residual branch is the level input, but it may not have the
            # right number of features to be added; if so, expand it with a conv
            nb_feats_in = lvl_first_tensor.get_shape()[-1]
            nb_feats_out = convarm_layer.get_shape()[-1]
            add_layer = lvl_first_tensor
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_down_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)

                if conv_dropout > 0:
                    # dropout must go on the expanded branch itself, since that is what feeds the merge
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    add_layer = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(add_layer)

            name = '%s_res_down_merge_%d' % (prefix, level)
            last_tensor = tf.keras.layers.add([add_layer, convarm_layer], name=name)

            name = '%s_res_down_merge_act_%d' % (prefix, level)
            last_tensor = tf.keras.layers.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_down_%d' % (prefix, level)
            last_tensor = tf.keras.layers.BatchNormalization(axis=batch_norm, name=name, fused=fused_batch_norm)(last_tensor)

        # max pool if we're not at the last level
        if level < (nb_levels - 1):
            name = '%s_maxpool_%d' % (prefix, level)
            last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)

    # create the model and return
    model = tf.keras.Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
    return model


def conv_dec(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             nb_labels,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             use_skip_connections=False,
             padding='same',
             dilation_rate_mult=1,
             activation='elu',
             use_residuals=False,
             final_pred_activation='softmax',
             nb_conv_per_level=2,
             layer_nb_feats=None,
             batch_norm=None,
             conv_dropout=0,
             input_model=None):
    """
    Fully Convolutional Decoder. Copied with permission from github.com/adalca/neurite.

    Builds the up arm of a U-Net. It has nb_levels - 1 levels, each made of upsampling,
    an optional skip concatenation and nb_conv_per_level convolutions. These are
    followed by a 1x1 convolution to nb_labels channels and the final prediction layer.

    Parameters:
        nb_features: Base feature count. Up level `level` uses
            round(nb_features * feat_mult ** (nb_levels - 2 - level)) features unless
            layer_nb_feats is given.
        input_shape: Input shape (spatial dims followed by channels). Only used when
            input_model is None. Otherwise it is taken from input_model's output.
        nb_levels: Number of levels of the matching encoder.
        conv_size: Convolution kernel size.
        nb_labels: Number of output channels of the prediction.
        name: Name of the returned model.
        prefix: Prefix for all layer names. Defaults to `name`.
        feat_mult: Per-level feature multiplier.
        pool_size: Upsampling factor, an int or a per-dimension tuple.
        use_skip_connections (bool): If true, turns an Enc-Dec into a U-Net. The last
            conv of each encoder level, looked up in input_model by the name
            '<prefix>_conv_downarm_<level>_<conv>', is concatenated after upsampling.
            Requires input_model.
        padding: Convolution padding mode.
        dilation_rate_mult: Up level `level` uses dilation rate
            dilation_rate_mult ** (nb_levels - 2 - level).
        activation: Activation of the hidden convolutions.
        use_residuals: Add a residual connection around each level's conv block.
        final_pred_activation: 'softmax' applies a softmax over the channel axis.
            Any other value yields a linear (identity) output.
        nb_conv_per_level: Number of convolutions per level.
        layer_nb_feats: Optional explicit feature count for every decoder conv layer, in order.
        batch_norm: If not None, the axis passed to BatchNormalization after each level.
        conv_dropout: Dropout rate applied along the feature axis after convolutions.
        input_model: Encoder model whose output feeds the decoder.

    Returns:
        A tf.keras.Model mapping the input (the encoder input when input_model is
        given) to the prediction tensor.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # skip connections are pulled from the encoder model by layer name
    if use_skip_connections:
        assert input_model is not None, "if using skip connections, input_model is required"

    # first layer: input
    input_name = '%s_input' % prefix
    if input_model is None:
        input_tensor = tf.keras.layers.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        input_shape = last_tensor.shape.as_list()[1:]

    # vol size info (the last entry of input_shape is the channel count)
    ndims = len(input_shape) - 1
    input_shape = tuple(input_shape)
    if isinstance(pool_size, int):
        if ndims > 1:
            pool_size = (pool_size,) * ndims

    # prepare layers
    convL = getattr(tf.keras.layers, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation}
    upsample = getattr(tf.keras.layers, 'UpSampling%dD' % ndims)
    # fused batch norm is disabled on Apple-silicon macOS; None keeps the Keras default elsewhere
    fused_batch_norm = False if (platform.system() == 'Darwin' and platform.machine() == 'arm64') else None

    # up arm:
    # nb_levels - 1 layers of Deconvolution3D
    #    (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
    lfidx = 0
    for level in range(nb_levels - 1):
        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)

        # upsample matching the max pooling layers size
        name = '%s_up_%d' % (prefix, nb_levels + level)
        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
        up_tensor = last_tensor

        # merge with the last conv of the mirrored encoder level along the channel axis
        if use_skip_connections:
            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
            cat_tensor = input_model.get_layer(conv_name).output
            name = '%s_merge_%d' % (prefix, nb_levels + level)
            last_tensor = tf.keras.layers.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)

        # convolution layers
        for conv in range(nb_conv_per_level):
            if layer_nb_feats is not None:
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:
                # last conv of a residual level has no activation; it is applied after the merge
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only (dropout layers keep automatic names)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(last_tensor)

        # residual block
        if use_residuals:

            # the residual branch is the upsampled tensor, but it may not have the
            # right number of features to be added; if so, expand it with a conv
            add_layer = up_tensor
            nb_feats_in = add_layer.get_shape()[-1]
            nb_feats_out = last_tensor.get_shape()[-1]
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_up_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)

                if conv_dropout > 0:
                    # dropout belongs on the expanded branch; the conv arm already had
                    # dropout applied inside the conv loop above
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    add_layer = tf.keras.layers.Dropout(conv_dropout, noise_shape=noise_shape)(add_layer)

            name = '%s_res_up_merge_%d' % (prefix, level)
            last_tensor = tf.keras.layers.add([last_tensor, add_layer], name=name)

            name = '%s_res_up_merge_act_%d' % (prefix, level)
            last_tensor = tf.keras.layers.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_up_%d' % (prefix, level)
            last_tensor = tf.keras.layers.BatchNormalization(axis=batch_norm, name=name, fused=fused_batch_norm)(last_tensor)

    # compute likelihood prediction (no activation yet)
    name = '%s_likelihood' % prefix
    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
    like_tensor = last_tensor

    # output prediction layer
    # we use a softmax to compute P(L_x|I) where x is each location
    if final_pred_activation == 'softmax':
        name = '%s_prediction' % prefix
        softmax_lambda_fcn = lambda x: tf.keras.activations.softmax(x, axis=ndims + 1)
        pred_tensor = tf.keras.layers.Lambda(softmax_lambda_fcn, name=name)(last_tensor)

    # otherwise create a layer that does nothing.
    else:
        name = '%s_prediction' % prefix
        pred_tensor = tf.keras.layers.Activation('linear', name=name)(like_tensor)

    # create the model and return
    model = tf.keras.Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
    return model


if __name__ == '__main__':
    # run the command-line entry point only when executed as a script, not on import
    main()
