#!/usr/bin/python3

import time
import os
import shutil
import time
import sys
import random
import argparse
import glob
from typing import List, Optional

from math import prod
from scipy import ndimage

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np

import surfa as sf
from surfa.system import fatal as sfs_fatal


# Help text displayed by argparse (--help). Typos fixed: "pitutiary" ->
# "pituitary", "< subjectN>" -> "<subjectN>".
description = """
Segmentation of pituitary and pineal glands. Produces the following labels:
Pituitary-Ant (883), 
Pituitary-Post (903),
Infundibulum (904), and
Pineal (900).

The code has two input modes:
(1) Normal mode, which requires one or more filenames or directories (--i) 
and optional output filenames or directory (--o) (EX: mri_pglands_seg --i 
<infile1> ... <infileN> --o <outdir>).
(2) FreeSurfer mode, which requires either a list of subjects (--s) and/or the
environment variable SUBJECTS_DIR (--sd) (EX: mri_pglands_seg --s <subject1>
... <subjectN> --sd <SUBJECTS_DIR>). In FreeSurfer mode, output paths will be
set automatically in the FreeSurfer filename/directory convention, but the user
can override this by specifying an output directory if desired.
"""


#-----------------------------------------------------------------------------#
#                               Main entry point                              #
#-----------------------------------------------------------------------------#

def main():
    """
    Entry point: validate the FreeSurfer environment, parse commandline
    arguments, configure the torch device, build the input/output path
    dictionary, and run the segmenter over every input image.
    """
    # Test that FREESURFER_HOME has been properly configured
    if not os.environ.get('FREESURFER_HOME'):
        sfs_fatal('FREESURFER_HOME is not set. Please source FreeSurfer.')

    # Set up argument parser
    parser = argparse.ArgumentParser(description=description)
    parser = _define_args(parser)

    # No arguments at all: show usage instead of an obscure error
    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit(1)

    # Parse commandline
    pargs = parser.parse_args()

    # Sanity checks: need at least one input source, and --i / --s are
    # mutually exclusive input modes
    if pargs.i is None and pargs.s is None and pargs.sd is None:
        sfs_fatal('Must provide at least either: input image/directory '
                  '(--i), list of Freesurfer subjects (--s), or Freesurfer '
                  'subjects directory (--sd)')

    if pargs.i is not None and pargs.s is not None:
        sfs_fatal('Cannot provide both input image (--i) and subject (--s) '
                  'flags. Choose one input mode.')

    # Configure device (CPU unless --use_cuda and CUDA is available)
    torch.set_float32_matmul_precision('medium')
    device = torch.device('cpu')

    if pargs.use_cuda:
        if torch.cuda.is_available():
            device = torch.device('cuda')
            # Deterministic cudnn kernels for reproducible segmentations.
            # BUGFIX: removed torch.autograd.set_detect_anomaly(True) --
            # anomaly detection is a debugging aid and PyTorch documents it
            # as significantly slowing execution; it has no place in an
            # inference-only production path.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
        else:
            print('CUDA is unavailable, running w/ cpu')

    # Prepare data. FreeSurfer mode is implied whenever no explicit inputs
    # (--i) are given; eTIV is always included in FS mode.
    mode = 'FS' if pargs.i is None else 'normal'
    include_etiv = bool(pargs.etiv) or mode == 'FS'

    paths_dict = get_filenames(mode=mode,
                               FS_subject_ids=pargs.s,
                               FS_subjects_dir=pargs.sd,
                               inpath=pargs.i,
                               outpath=pargs.o,
                               outbase=pargs.outbase,
                               talaff=pargs.tal,
                               mniaff=pargs.mni_template_transform,
                               qawarp=pargs.qa_transform,
    )
    n_images = len(paths_dict['input_volumes'])

    # Initialize segmenter
    segmenter = PGlandsSegmenter(paths_dict=paths_dict,
                                 model_path=pargs.model,
                                 lut_path=pargs.lut,
                                 template_path=pargs.mni_template,
                                 crop_size=pargs.crop_patch_size,
                                 use_center_crop=pargs.center_crop,
                                 norm_percent=pargs.robust_norm_percent,
                                 write_posts=pargs.write_posteriors,
                                 write_vols=pargs.write_vol_stats,
                                 write_qas=pargs.write_qa_stats,
                                 include_etiv=include_etiv,
                                 device=device
    )

    # Loop through images
    for idx in range(n_images):
        segmenter(idx)
        
    

#------------------------------------------------------------------------------

def _define_args(parser):
    """
    Register all commandline options on *parser* and return it.

    Covers I/O options (--i/--s/--sd/--o), general segmentation options,
    and optional output controls (posteriors, volume stats, QA stats).
    """
    # Set defaults (NOTE(review): accelerator/devices/num_sanity_val_steps/
    # deterministic look like leftover pytorch-lightning Trainer defaults;
    # nothing in this file reads them -- confirm before removing)
    parser.set_defaults(accelerator='gpu')
    parser.set_defaults(devices=1)
    parser.set_defaults(num_sanity_val_steps=0)
    parser.set_defaults(deterministic=False)

    # I/O
    parser.add_argument('-i', '--i', nargs='*', help='T1-w image(s) to '
                        'segment. Can be either a single directory or path(s) '
                        'to one or more images.')
    parser.add_argument('-s', '--s', nargs='*', help='Process a series of FS '
                        'recon-all subjects (enables subject-mode).')
    parser.add_argument('--sd', help='Set the subjects directory (overrides '
                        'the SUBJECTS_DIR env variable).')
    parser.add_argument('-o', '--o', nargs='*', help='Segmentation output '
                        '(optional). Can either be a single directory or '
                        'path(s) to one or more file. If input (--i) is a '
                        'directory, then the output (--o) must also be a '
                        'directory. If the input (--i) is one or more '
                        'filenames, then the output (--o) must be either a '
                        'single directory or the same number of file paths. '
                        'If running in FreeSurfer mode (e.g. inputs are given '
                        'with --s and/or --sd), then the output (--o) must be '
                        'a directory. Note that if the provided output is a '
                        'directory, a subdirectory will automatically be '
                        'created for each input image and named using the '
                        'input image file basename (w/o extension). If no '
                        'output is provided, then the output paths will be '
                        'parsed from the input paths (in both normal and '
                        'FreeSurfer mode).')

    # General options.
    # BUGFIX: os.environ.get() returns None when the variable is unset and
    # os.path.join(None, ...) raises TypeError before argparse can even
    # print help. Fall back to FREESURFER_HOME (validated in main()).
    fspyhome = os.environ.get('FREESURFER_HOME_FSPYTHON')
    if fspyhome is None:
        fspyhome = os.environ.get('FREESURFER_HOME', '')
    default_path = os.path.join(fspyhome, 'models/pglands_seg')
    parser.add_argument('--use_cuda', action='store_true',
                        help='Use Cuda for GPU support')
    parser.add_argument('--center_crop', action='store_true',
                        help='Use image center for crop window (and skip '
                        'affine registration to mni152 space)')
    parser.add_argument('--crop_patch_size', type=int, default=96,
                        help='isotropic size (# of voxels) for crop patch; '
                        'default is 96')
    parser.add_argument('--logfile', type=str,
                        help='Set logfile basename; default is '
                        'mri_pglands_seg.log (and automatically writes to the '
                        'output dir)')
    parser.add_argument('--lut', type=str,
                        default=os.path.join(default_path, 'pglands.ctab'),
                        help='Path to lut; default is pglands.ctab')
    parser.add_argument('--mni_template', type=str,
                        default=os.path.join(
                            default_path, 'mni152_label_template.mgz'),
                        help='MNI label template used to set crop window; '
                        'default is mni152_label_template.mgz')
    parser.add_argument('--mni_template_transform', type=str,
                        help='Path to affine transform for registering the '
                        'mni label template to subject space. This will '
                        'override the default syntax, and will apply the '
                        'transform to all subjects.')
    parser.add_argument('--model', type=str, 
                        default=os.path.join(default_path, 'pglands_seg.pth'),
                        help='Path to trained model; default is '
                        'pglands_seg.pth')
    parser.add_argument('--outbase', type=str, default='pglands',
                        help='String to use in output filename; '
                        'default is "pglands"')
    parser.add_argument('--robust_norm_percent', type=float, default=0.95,
                        help='Percentile for robust normalization of input '
                        'image intensities; default is 0.95')

    # Other outputs (and associated args)
    parser.add_argument('--write_posteriors', action='store_true',
                        help='Flag to save label posteriors')
    parser.add_argument('--write_qa_stats', action='store_true',
                        help='Flag to calculate and write QA stats.')
    parser.add_argument('--qa_transform', type=str,
                        help='Deformation from subject to mni152 space to '
                        'perform QA analysis on output segmentation. Can '
                        'either be a path to a single image or a suffix (will '
                        'search output directory for matching filename).')
    parser.add_argument('--write_vol_stats', action='store_true',
                        help='Flag to calculate and save label volumes')
    parser.add_argument('--etiv', action='store_true',
                        help='Flag to include eTIV in volume stats; default '
                        'is true in FS mode and false in normal mode. If set '
                        'to true in normal mode, it will significantly '
                        'increase run time.')
    parser.add_argument('--tal', type=str,
                        help='Alternative talairach xfm transform for '
                        'estimating eTIV. Can either be a path to a single '
                        'transform or a suffix (will search output directory '
                        'for matching filename).')

    return parser



#------------------------------------------------------------------------------
#                          Custom pglands segmenter
#------------------------------------------------------------------------------

class PGlandsSegmenter:
    """
    Main class that performs segmentation of pituitary+pineal glands for a set 
    of input MRIs. All required filenames should be stored in the paths_dict 
    input (created with the get_filenames() function). The general workflow 
    to segment a single image is:
    1. Parse all filenames
       - General I/O: input T1 (invol), output segmentation (outvol) 
       - Output stats files: label volumes (volstats), QA scores (qastats)
       - Required transforms: affine from mni to subject, warp from mni to 
         subject, and affine talairach for eTIV
    2. Preprocessing
       - Register (affine) mni_label_template to subject space (if not using 
         center_crop)
       - Crop volume to smaller patch window (centered on label template or on 
         input volume)
       - Robust (95%) intensity normalization
    3. Segmentation
       - Process image w/ trained model to yield posteriors
       - Generate hard segmentation from softmax posteriors
    4. Volume calculations
       - Calculate number of voxels and total volume of each label in the 
         output segmentation
    5. QA analysis
       - Deformably register output segmentation to mni152 space
       - Calculate pairwise dice between mni152 segmentation and QA dataset 
         (the set of all ground truth labels in mni152 space)
       - Take the maximum Dice for each label
    """
    def __init__(self,
                 paths_dict:dict,            # dictionary with filenames
                 model_path:str,             # path to trained model
                 lut_path:str=None,          # path to label look-up table
                 template_path:str=None,     # path to mni152 label template
                 template_aff:str=None,      # path to affine from mni152->subj
                 crop_size:int=96,           # isotropic crop window size
                 use_center_crop:bool=True,  # use img center for crop (no mni)
                 norm_percent:float=0.95,    # % for robust normalization
                 device:str=None,            # run on GPU (cuda) or CPU
                 include_etiv:bool=False,    # flag to include etiv in stats
                 write_posts:bool=False,     # flag to write posterior imgs
                 write_vols:bool=False,      # flag to write volume stats
                 write_qas:bool=False,       # flag to write qa stats
    ):
        # Parse inputs
        self.paths_dict = paths_dict
        self.crop_patch_size = crop_size
        self.use_center_crop = use_center_crop
        self.include_etiv = include_etiv
        self.norm_percent = norm_percent
        self.device = device

        self.write_posteriors = write_posts
        self.write_vol_stats = write_vols
        self.write_qa_stats = write_qas
        module_dir = os.path.join(
            os.environ.get('FREESURFER_HOME_FSPYTHON'), 'models/pglands_seg'
        )

        # Data utilities
        default_lut = os.path.join(module_dir,'pglands.ctab')
        self.lut = sf.load_label_lookup(
            lut_path if lut_path is not None else default_lut
        )
        default_template = os.path.join(module_dir,'mni152_label_template.mgz')
        self.label_template_path = template_path if template_path is not None \
            else default_template

        # Set up trained model
        default_model = os.path.join(module_dir,'pglands_seg.pth')
        model_path = default_model if model_path is None else model_path
        model_checkpoint = torch.load(model_path, map_location=self.device)

        self.model = UNet3D(in_channels=1,
                            out_channels=len(self.lut)
        ).to(self.device)
        self.model.load_state_dict(model_checkpoint['model_state'])
        
        # Initialize QA dataset (if necessary)
        self.qa_dataset = self.QAdataset() if write_qas else None            
        
    class QAdataset(Dataset):
        """
        Data loader for training data used in QA analysis
        """
        def __init__(self):
            """
            Will pull list of qa images from default dir
            """
            qa_dir = os.path.join(os.environ.get('FREESURFER_HOME_FSPYTHON'),
                                  'models/pglands_seg/pglands_seg_qa_data')
            default_qa_paths = [os.path.join(qa_dir, x)
                                for x in os.listdir(qa_dir)]
            self.imglist = default_qa_paths

        def __len__(self):
            return len(self.imglist)

        def __getitem__(self, idx):
            img = load_volume(self.imglist[idx], conform=False, is_int=True)[0]
            return img, idx


    def _compute_label_volumes(self, img:sf.Volume, etiv=None):
        """
        Calculate volumes of each label
        """
        voxvol = np.prod(img.geom.voxsize)
        n_voxs = [(img.data == val).sum() for val in self.lut]
        vols = [voxvol * n for n in n_voxs]
        return vols, n_voxs
    

    def _compute_etiv(self, lta:str, invol:str=None, outdir:str=None):
        """
        Estimate eTIV from input lta talairach transform. If lta does not 
        exist, calculate from scratch (will increase run time).
        """
        if not os.path.isfile(lta):
            print(f'Computing talairach transform ({lta}) to measure eTIV for '
                  'volume stats analysis (this should take about 3 minutes)')
            
            # Make temporary files
            tmpdir = os.path.join(os.path.dirname(lta), 'temp')
            norm = os.path.join(tmpdir, 'nu.mgz')
            xfm = os.path.join(tmpdir, 'talairach.xfm')
            os.makedirs(tmpdir, exist_ok=True)

            # Conform the image
            cmd = f'mri_convert --conform {invol} {norm} > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sfs_fatal('mri_convert --conform failed!')

            # Run intensity normalization
            cmd = f'mri_nu_correct.mni --no-rescale --i {norm} --o {norm} ' +\
                f'--n 1 --proto-iters 1000 --distance 50 --ants-n4 > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sf.system.fatal('mri_nu_correct failed!')

            # run talairach registration
            cmd = f'talairach_avi --i {norm} --xfm {xfm} > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sf.system.fatal('talairach_avi failed!')

            # convert XFM to LTA
            mni305 = os.path.join(os.environ.get('FREESURFER_HOME'),
                                  'average', 'mni305.cor.mgz')
            cmd = f'lta_convert --src {norm} --trg {mni305} --inxfm {xfm} ' +\
                f'--outlta {lta} --subject fsaverage --ltavox2vox > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sf.system.fatal('lta_convert failed!')                
            shutil.rmtree(tmpdir)

        # Calculate eTIV
        scale_factor = 1948.106
        return 1e3 * scale_factor / sf.load_affine(lta).det()
            

    def _get_subject_label_template(self, inpath:str, trpath:str, outdir:str):
        """
        Register and load mni152 label template in subject space
        """
        # Set up directory
        regdir = os.path.dirname(trpath)
        os.makedirs(regdir, exist_ok=True)
        
        # Calculate affine transform
        if not os.path.isfile(trpath):
            cmd = f'fs-synthmorph-reg --i {inpath} --o {regdir} ' +\
                '--mni-out-res 1.0mm --mni-targ-res 1.0mm ' +\
                '--no-crop --pituitary --affine-only > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sfs_fatal('fs-synthmorph-reg failure :(')

        # Apply transform
        outbase, _ = split_extension(os.path.basename(inpath))
        tmpbase, ext = split_extension(
            os.path.basename(self.label_template_path)
        )
        outpath = os.path.join(outdir, '.'.join([tmpbase, outbase, ext[1:]]))
        if not os.path.isfile(outpath):
            cmd = f'mri_synthmorph apply -m nearest {trpath} ' +\
                f'{self.label_template_path} {outpath} > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sfs_fatal('mri_synthmorph apply failure :(')

        return load_volume(outpath, is_int=True)[0]

    

    def _preprocess(self, x, template=None):
        """
        Perform pre-processing steps (crop + intensity norm)
        """
        x, crop_bounds = crop_volume(x, template, self.crop_patch_size)
        x = robust_norm(x, self.norm_percent) if self.norm_percent < 1. \
            else norm_minmax(x)
        return x.unsqueeze(0).unsqueeze(0), crop_bounds
        

    def _predict(self, x):
        """
        Predict segmentations w/ trained model
        """
        with torch.no_grad():
            logits = self.model(x)
        posteriors = torch.nn.functional.softmax(logits, dim=1)        
        return posteriors.squeeze()


    def _qa_analysis(self, movpath:str, segpath:str, trpath:str):
        """
        Perform the QA analysis on output segmentation
        """
        # Set up registration dir
        regdir = os.path.dirname(trpath)
        os.makedirs(regdir, exist_ok=True)

        # Calculate deformation to mni152 space
        if not os.path.isfile(trpath):
            print('Calculating deformable warp from '
                  f'{os.path.basename(movpath)} to mni152 space for QA '
                  'analysis (this should take about 15 mins)')
            regdir = os.path.dirname(trpath)
            cmd = f'fs-synthmorph-reg --i {movpath} --o {regdir} ' +\
                '--mni-out-res 1.0mm --mni-targ-res 1.0mm ' +\
                '--no-crop --pituitary > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sfs_fatal('fs-synthmorph-reg failure :(')
                
        # Apply warp
        fbase, ext = split_extension(segpath)
        mnisegpath = '.'.join([fbase, 'mni152', ext[1:]])
        if not os.path.isfile(mnisegpath):
            cmd = f'mri_synthmorph apply -m nearest {trpath} ' +\
                f'{segpath} {mnisegpath} > /dev/null'
            print(cmd[:-12])
            if sf.system.run(cmd) != 0:
                sfs_fatal('mri_synthmorph apply failure :(')

        # Calculate pairwise dice between output seg and all QA images
        x = load_volume(mnisegpath, conform=False, is_int=True)[0]
        qa_dice = np.zeros((len(self.qa_dataset), len(self.lut)))
        qa_dice[:] = np.nan
        for y, idx in self.qa_dataset:
            qa_dice[idx, :] = [2 * ((x==n) * (y==n)).sum()
                               / ((x==n).sum() + (y==n).sum())
                               for n in self.lut
            ]
        return np.max(qa_dice, axis=0)


    def _write_stats_file(self, data:list, fspecs:list, fname:str,
                          colnames:Optional[List[str]]=None, header:str=None,
                          etiv:float=None):
        """
        Write stats file in the FreeSurfer formatting convention
        """
        if colnames is not None:
            colnames = [colnames] if isinstance(colnames, str) else colnames

        with open(fname, 'w') as f:
            if header is not None:
                f.write(f'# {header}\n')
            f.write('# Created by mri_pglands_seg\n')
            if etiv is not None:
                f.write('# Measure EstimatedTotalIntraCranialVol, eTIV, '
                        'Estimated Total Intracranial Volume,'
                        f'{etiv:.3f}, mm^3\n')

            f.write(f'# NRows {len(self.lut)}\n')
            f.write('# NTableCols 5\n')
            f.write('# ColHeaders Index SegId')
            for name in colnames:
                f.write(f' {name}')
            f.write(' StructName\n')

            for n, val in enumerate(self.lut):
                f.write(f'{n:>4} {val:>6}')
                for j in range(len(data)):
                    fstr = fspecs[j].format(data[j][n])
                    f.write(' '+fstr)
                f.write(f' {self.lut[val].name}\n')
        
        
    def __call__(self, idx):
        """
        Main function for segmenter
        """
        # Grab filenames
        invol = self.paths_dict['input_volumes'][idx]
        outvol = self.paths_dict['output_segmentations'][idx]
        
        mniaff = self.paths_dict['mni_template_transforms'][idx] \
            if not self.use_center_crop else None
        
        volstats = self.paths_dict['output_vol_stats'][idx] \
            if self.write_vol_stats else None
        talaff = self.paths_dict['etiv_transforms'][idx] \
            if self.write_vol_stats and self.include_etiv else None
        
        qastats = self.paths_dict['output_qa_stats'][idx] \
            if self.write_qa_stats else None
        qawarp = self.paths_dict['qawarps'][idx] \
            if self.write_qa_stats else None
        

        # Load input image
        print(f'Performing pglands segmentation of {invol}')
        X, input_geom, conform_geom = load_volume(invol)

        # Load template (if necessary)
        template = None if mniaff is None \
            else self._get_subject_label_template(invol, mniaff,
                                                  os.path.dirname(outvol))
        
        # Run model
        X, crop_bounds = self._preprocess(
            X.to(self.device),
            template=None if template is None else template.to(self.device)
        )
        posteriors = self._predict(X)
        onehot = torch.argmax(posteriors, dim=0).squeeze()
        seg = largest_connected_component(
            onehot_to_labels(
                torch.argmax(
                    posteriors, dim=0).squeeze(), self.lut
            )
        )

        # Output final segmentation
        print(f'Segmentation complete! saving output as {outvol}')
        seg = save_volume(img=seg, path=outvol, crop_bounds=crop_bounds,
                          input_geom=input_geom, conform_geom=conform_geom,
                          is_labels=True, label_lut=self.lut,
                          return_output=True)

        # Output posteriors
        if self.write_posteriors:
            fbase, ext = split_extension(outvol)
            print(f'Saving posteriors as {fbase}.posterior.<label_name>.mgz')
            
            for n, val in enumerate(self.lut):
                label_name = self.lut[val].name
                if label_name != 'unknown':
                    path = '.'.join([fbase, 'posterior', label_name, ext[1:]])
                    save_volume(img=y[n], path=path,
                                crop_bounds=crop_bounds,
                                input_geom=input_geom,
                                conform_geom=conform_geom)
                    
        # Output volume stats
        if self.write_vol_stats:
            etiv = self._compute_etiv(talaff, invol) if self.include_etiv \
                else None
            print(f'Computing label stats (saving as {volstats}')
            vols, n_voxs = self._compute_label_volumes(seg, etiv)
            self._write_stats_file(data=[vols, n_voxs],
                                   fname=volstats,
                                   colnames=['NVoxels','Volume_mm3'],
                                   fspecs=['{:>12.0f}', '{:>12.3f}'],
                                   header='PGlands Volume Stats',
                                   etiv=etiv)
        # Output QA stats
        if self.write_qa_stats:
            print(f'Computing QA stats (saving as {qastats})')
            qa_scores = self._qa_analysis(movpath=invol,
                                          segpath=outvol,
                                          trpath=qawarp)
            self._write_stats_file(data=[qa_scores],
                                   fname=qastats,
                                   colnames=['QAscore'],
                                   fspecs=['{:>6.4f}'],
                                   header='PGlands QA Scores')

        

#-----------------------------------------------------------------------------#
#                              Image data utils                               #
#-----------------------------------------------------------------------------#

###
###
class QAdataset(Dataset):
    """
    Data loader for training data used in QA analysis.

    Reads one image path per line from *imglist_path* and loads the label
    look-up table from *lut_path*.
    """
    def __init__(self,
                 imglist_path:str=None,
                 lut_path=None,
                 device:str=None
    ):
        self.device = device

        # Get list of images (one path per line)
        imglist_path = 'qa_image_filenames.txt' if imglist_path is None \
            else imglist_path

        # Context manager guarantees the handle is closed even on error.
        # BUGFIX: the old `x[:-1]` newline strip clipped the final character
        # of a last line lacking a trailing newline; splitlines() does not.
        with open(imglist_path, 'r') as f:
            self.imglist = f.read().splitlines()

        # Load label lut
        lut_path = 'pglands.ctab' if lut_path is None else lut_path
        self.lut = sf.load_label_lookup(lut_path)

    def __len__(self):
        # Number of QA images listed
        return len(self.imglist)

    def __getitem__(self, idx):
        # Return the (un-conformed) image and its index
        img = load_volume(self.imglist[idx], conform=False)[0]
        return img, idx
    

#------------------------------------------------------------------------------
# I/O functions

###
def get_filenames(mode:str='normal',              # FS or normal
                  FS_subject_ids:List[str]=None,  # list of FS subject ids (FS)
                  FS_subjects_dir:str=None,       # FS subjects dir (FS)
                  FS_basename:str='T1',           # basename for files (FS)
                  inpath:List[str]=None,          # input image(s)/dir (normal)
                  outpath:List[str]=None,         # output image/dir (normal)
                  outbase:str=None,               # filename base for outputs
                  talaff:str=None,                # .lta for talairach reg
                  mniaff:str=None,                # affine for mni-->subj reg
                  qawarp:str=None,                # warp for subj-->mni reg
):
    """
    Utility function to generate a dictionary of all input/output filenames
    based on commandline arguments. These include:
    - General I/O: input T1s, output segmentations
    - Output stats files: label volumes, QA scores
    - Required transforms: affines from mni to subject, warps from mni to
      subject, and affine talairachs for eTIV

    Returns a dict with keys 'input_volumes', 'output_segmentations',
    'output_posteriors', 'output_vol_stats', 'output_qa_stats',
    'etiv_transforms', 'mni_template_transforms', and 'qawarps'; each value
    is a list with one entry per input image.
    """
    def _is_basename(inpath:str):
        # True when the path has no directory component
        return inpath == os.path.basename(inpath)

    exts = ['.mgh', '.mgz', '.nii', '.nii.gz']
    talaff = 'talairach.xfm.lta' if talaff is None else talaff
    mniaff = 'reg.targ_to_invol.lta' if mniaff is None else mniaff
    qawarp = 'warp.to.mni152.1.0mm.1.0mm.nii.gz' if qawarp is None else qawarp

    ##
    if mode == 'FS':
        # Resolve SUBJECTS_DIR (commandline value overrides the env variable)
        sdir = os.getenv('SUBJECTS_DIR') if FS_subjects_dir is None \
            else os.path.abspath(FS_subjects_dir)

        if sdir is None:
            sfs_fatal('Must set subjects directory with --sd or SUBJECTS_DIR '
                      'env variable.')

        print(f'Using subject directory {sdir}')

        # If no subjects provided, search sdir for recon-all style subjects
        subjects = FS_subject_ids
        invols = None
        inbase = FS_basename

        if subjects is None or len(subjects) == 0:
            invols = glob.glob(f'{sdir}/*/mri/{inbase}.mgz')
            subjects = [os.path.basename(x.replace(f'/mri/{inbase}.mgz', ''))
                        for x in invols]

        if len(subjects) == 0:
            sfs_fatal(f'Subjects directory {sdir} does not contain any '
                      'valid recon-all subjects.')

        # Get list of input T1 images (if not already found above)
        invols = invols if invols is not None else \
            [os.path.join(sdir, x, f'mri/{inbase}.mgz') for x in subjects]

        # Configure output directory
        if isinstance(outpath, list) and len(outpath) > 1:
            sfs_fatal('Cannot provide multiple outpaths (--o) if running '
                      'in FreeSurfer mode')

        outpath = sdir if outpath is None else outpath[0]

        if is_valid_filename(outpath, exts):
            sfs_fatal('Output (--o) must be a directory (or left blank) if '
                      'running in FreeSurfer mode')

        # Set list of all output filenames for each subject
        n_images = len(invols)
        outbase = '.'.join([inbase, outbase])

        outvols = [None] * n_images
        posts = [None] * n_images
        volstats = [None] * n_images
        qastats = [None] * n_images
        talaffs = [None] * n_images
        mniaffs = [None] * n_images
        mni_warps = [None] * n_images

        for i, subject in enumerate(subjects):
            # Subdirs: use the standard FS layout when writing into sdir,
            # otherwise write everything into the subject directory
            subject_dir = os.path.join(sdir, subject)
            mri_dir = os.path.join(subject_dir, 'mri') \
                if outpath == sdir else subject_dir
            stats_dir = os.path.join(subject_dir, 'stats') \
                if outpath == sdir else subject_dir
            transform_dir = os.path.join(mri_dir, 'transforms') \
                if outpath == sdir else subject_dir
            mni_dir = os.path.join(transform_dir, 'mni152')

            # Outputs
            outvols[i] = os.path.join(mri_dir, f'{outbase}.mgz')
            volstats[i] = os.path.join(stats_dir, f'{outbase}.volumes.stats')
            qastats[i] = os.path.join(stats_dir, f'{outbase}.qa.stats')

            # Transforms: bare basenames are placed in the standard subdirs,
            # full paths are used as given
            talaffs[i] = talaff if not _is_basename(talaff) \
                else os.path.join(transform_dir, talaff)
            mniaffs[i] = mniaff if not _is_basename(mniaff) \
                else os.path.join(mni_dir, mniaff)
            mni_warps[i] = qawarp if not _is_basename(qawarp) \
                else os.path.join(mni_dir, qawarp)


    elif mode == 'normal':
        # Parse input path(s)
        invols = []

        inisfiles = all(is_filename(pth) for pth in inpath)
        if len(inpath) > 1 and not inisfiles:
            sfs_fatal('If input path list (--i) contains multiple entries, '
                      'all must be to files (not directories).')
        if inisfiles:
            # Input is filename(s)
            for pth in inpath:
                if not os.path.isfile(pth):
                    sfs_fatal(f'{pth} is not an existing directory or file')
                if not is_valid_filename(pth, exts):
                    sfs_fatal(f'Input {pth} is not a supported file type '
                              '(must be .mgz, .mgh, .nii, or .nii.gz)')
                invols += [pth]
        else:
            # Input is a directory -> find all valid filenames inside it
            vols = []
            for ext in exts:
                vols += sorted(glob.glob(os.path.join(inpath[0], '*' + ext)))
            if len(vols) < 1:
                print(f'No valid input files found in {inpath[0]}')
            invols += vols

        if len(invols) == 0:
            sfs_fatal('Found no existing directories or files in provided '
                      f'inputs ({" ".join(inpath)})')

        # Parse output path(s)
        outvols = []
        if outpath is None:
            # No outpath provided, derive outputs from the input filenames
            print('Output path(s) (--o) not provided... parsing from input '
                  'path(s).')
            for pth in invols:
                fdir, ext = split_extension(pth, exts)
                outvols += [os.path.join(
                    fdir, '.'.join([os.path.basename(fdir), outbase, ext[1:]])
                )]
        else:
            # Output paths provided -> check compatibility with inputs
            outisfiles = all(is_filename(pth) for pth in outpath)
            if len(outpath) > 1 and not outisfiles:
                sfs_fatal('If output path list (--o) contains multiple '
                          'entries, all must be to valid file names.')
            if outisfiles and len(inpath) != len(outpath):
                sfs_fatal('If providing multiple output file paths (--o), the '
                          'number of output paths must equal the number of '
                          'input paths.')
            if not inisfiles and outisfiles:
                sfs_fatal('Cannot provide directory as input (--i) and '
                          'filename as output (--o).')

            if outisfiles:
                # Output is filename(s)
                for pth in outpath:
                    if not is_valid_filename(pth, exts):
                        sfs_fatal(f'Output {pth} is not a supported file type '
                                  '(must be .mgz, .mgh, .nii, or .nii.gz)')
                    # BUGFIX: append inside the loop; previously this line
                    # sat outside the loop so only the last path was kept
                    outvols += [pth]
            else:
                # Output is dir -> make list of filenames from invols
                for inname in invols:
                    fbase = os.path.basename(inname)
                    fbase, ext = split_extension(fbase, exts)
                    outname = '.'.join([fbase, outbase, ext[1:]])
                    outvols += [os.path.join(outpath[0], fbase, outname)]

        # Optional outputs and transforms, one entry per output volume
        n_images = len(outvols)

        posts = [None] * n_images
        volstats = [None] * n_images
        qastats = [None] * n_images
        talaffs = [None] * n_images
        mniaffs = [None] * n_images
        mni_warps = [None] * n_images

        for i, fname in enumerate(outvols):
            fbase, ext = split_extension(fname, exts)

            # Outputs
            posts[i] = '.'.join([fbase, 'posterior', ext[1:]])
            volstats[i] = '.'.join([fbase, 'volumes.stats'])
            qastats[i] = '.'.join([fbase, 'qa.stats'])

            # Transforms
            outdir = os.path.dirname(fbase)
            taldir = os.path.join(outdir, 'transforms')
            mnidir = os.path.join(outdir, 'transforms/mni152')

            talaffs[i] = talaff if not _is_basename(talaff) \
                else os.path.join(taldir, talaff)
            mniaffs[i] = mniaff if not _is_basename(mniaff) \
                else os.path.join(mnidir, mniaff)
            mni_warps[i] = qawarp if not _is_basename(qawarp) \
                else os.path.join(mnidir, qawarp)

    else:
        # Previously an unknown mode fell through to a NameError below
        sfs_fatal(f"Unknown mode '{mode}' (must be 'FS' or 'normal')")

    ## Store paths in a single dictionary
    fpaths = {'input_volumes': invols,
              'output_segmentations': outvols,
              'output_posteriors': posts,
              'output_vol_stats': volstats,
              'output_qa_stats': qastats,
              'etiv_transforms': talaffs,
              'mni_template_transforms': mniaffs,
              'qawarps': mni_warps}

    return fpaths


###
def load_volume(path:str,             # Path to load
                shape=(256,256,256),  # Output image dimensions
                voxsize=1.0,          # Output image resolution
                orientation='RAS',    # Output image orientation
                is_int:bool=False,    # Flag if image is int or float
                conform:bool=True,    # Flag to conform image
                to_tensor:bool=True,  # Flag to convert sf.Volume to tensor
):
    """
    Load a volume with surfa, optionally conforming it to a target geometry
    and converting it to a torch tensor.

    Returns a tuple (image, original_geometry, conformed_geometry), where
    image is a tensor when to_tensor=True, otherwise an sf.Volume.
    """
    vol = sf.load_volume(path)
    orig_geom = vol.geom

    # When conform=False, pass None for the geometry arguments so the image
    # keeps its native shape/resolution/orientation
    target = dict(shape=shape, voxsize=voxsize, orientation=orientation) \
        if conform else dict(shape=None, voxsize=None, orientation=None)
    vol = vol.conform(dtype=np.int32 if is_int else np.float32,
                      method='nearest' if is_int else 'linear',
                      **target)

    if not to_tensor:
        return vol, orig_geom, vol.geom
    data = torch.Tensor(vol.data).to(torch.int if is_int else torch.float)
    return data, orig_geom, vol.geom


###
def save_volume(img,                      # image data
                path,                     # path to save image
                input_geom=None,          # input image geometry
                conform_geom=None,        # conformed image geometry
                crop_bounds=None,         # bounds of data cropping
                label_lut=None,           # lut associated w/ image
                is_labels:bool=False,     # flag if output is label image
                return_output:bool=False  # flag to return conformed output
):
    """
    Save an image volume, optionally padding it back to the full conformed
    grid (crop_bounds) and resampling it into its original geometry
    (input_geom + conform_geom). Optionally returns the saved volume.
    """
    # Move to numpy and cast to the output dtype
    data = img.cpu().numpy() if torch.is_tensor(img) else img
    data = data.astype(np.int32 if is_labels else np.float32)

    # Undo any cropping before attaching the conformed geometry
    if crop_bounds is not None:
        data = pad_volume(data, crop_bounds)
    out = sf.Volume(data, geometry=conform_geom)

    # Resample back into the original image geometry, if provided
    if input_geom is not None:
        out = out.conform(shape=input_geom.shape,
                          voxsize=input_geom.voxsize,
                          orientation=input_geom.orientation,
                          method='nearest' if is_labels else 'linear')

    # Write image (with its lut, for label volumes)
    if label_lut is not None:
        out.labels = label_lut
    out.save(path)

    return out if return_output else None


#------------------------------------------------------------------------------
# Image manipulation functions

###
def crop_volume(x,                       # input data
                crop_template=None,      # template to set crop window
                crop_sz:List[int]=None,  # size of patch to extract
):
    """
    Crop a patch of size crop_sz out of x, centered on the bounding box of
    the nonzero voxels in crop_template, or on the image center when no
    template is given. Windows that would extend past the image edge are
    shifted back inside. Assumes three spatial dimensions (the template
    bounding-box search and the final indexing are written for 3D).

    Returns (cropped_image, bounds) where bounds is a [start, stop] pair
    per axis into the full image.
    """
    full_sz = x.shape
    # Normalize crop_sz: scalar -> per-axis list; None -> full image size
    # (the original crashed on crop_sz=None)
    if crop_sz is None:
        crop_sz = list(full_sz)
    elif isinstance(crop_sz, int):
        crop_sz = [crop_sz] * len(full_sz)

    if crop_template is not None:
        # Check if image/template sizes are equal
        if crop_template.shape != full_sz:
            sfs_fatal('Input image and template must have equal sizes')

        # Shrink a full-size bounding box until it tightly encloses the
        # nonzero voxels of the template (assumes a nonzero 3D template)
        bbox = [[0, full_sz[i]-1] for i in range(len(full_sz))]
        while crop_template[bbox[0][0],:,:].sum() == 0: bbox[0][0] += 1
        while crop_template[bbox[0][1],:,:].sum() == 0: bbox[0][1] -= 1
        while crop_template[:,bbox[1][0],:].sum() == 0: bbox[1][0] += 1
        while crop_template[:,bbox[1][1],:].sum() == 0: bbox[1][1] -= 1
        while crop_template[:,:,bbox[2][0]].sum() == 0: bbox[2][0] += 1
        while crop_template[:,:,bbox[2][1]].sum() == 0: bbox[2][1] -= 1

        center = [(bb[0] + bb[1])//2 for bb in bbox]

    else:
        # Center crop patch in input image
        center = [vs//2 for vs in full_sz]

    # Window of exactly ps voxels per axis. BUGFIX: the original used
    # [c - ps//2, c + ps//2], which drops one voxel for odd patch sizes.
    bounds = [[c - ps//2, c - ps//2 + ps] for c, ps in zip(center, crop_sz)]

    # Shift any window that extends outside the image back inside
    for i, (lo, hi) in enumerate(bounds):
        if lo < 0:
            shift = lo
        elif hi > full_sz[i]:
            shift = hi - full_sz[i]
        else:
            shift = 0
        bounds[i] = [lo - shift, hi - shift]

    # Crop the input image
    h, w, d = bounds
    x_crop = x[..., h[0]:h[1], w[0]:w[1], d[0]:d[1]]

    return x_crop, bounds


###
def largest_connected_component(x, vals=None, bgval=0):
    """
    Keep only the largest 26-connected component of each foreground label
    in a multi-label image.

    Parameters: x is a label volume (tensor or ndarray); vals is the set of
    label values to process (defaults to all values in x); bgval is the
    background value to skip. Returns an ndarray with x's dtype.
    """
    x = x.cpu().numpy() if torch.is_tensor(x) else x
    vals = np.unique(x) if vals is None else vals
    vals = [v for v in vals if v != bgval]

    out = np.zeros(x.shape, dtype=x.dtype)
    structure = np.ones((3, 3, 3))  # 26-connectivity

    for val in vals:
        mask = np.squeeze(np.where(x == val, 1, 0))
        cc_map, n_cc = ndimage.label(mask, structure)

        if n_cc < 1:
            # Label absent from the image (possible when vals is passed
            # explicitly); the original crashed on the empty max() here
            continue
        if n_cc == 1:
            keep = 1
        else:
            # BUGFIX: pick the largest component with argmax. The original
            # used boolean indexing + .item(), which raised on ties and fell
            # into a bare except with broken fallback indexing.
            cc_ids = np.unique(cc_map)[1:]
            sizes = np.array([(cc_map == i).sum() for i in cc_ids])
            keep = cc_ids[np.argmax(sizes)].item()

        # Labels are disjoint, so summing the per-label images is safe
        out = out + np.where(cc_map == keep, val, 0).astype(out.dtype)

    return out



def onehot_to_labels(onehot,  # input tensor of predicted class indices
                     lut,     # surfa label lookup table
):
    """
    Map predicted class indices back to their label values via the lut.

    Compares against the unmodified input tensor, so a written label value
    that happens to equal a not-yet-processed class index is not remapped a
    second time (the original in-place version had this aliasing bug).

    NOTE(review): despite the name, `onehot` is expected to hold integer
    class indices (e.g. an argmaxed prediction), not a one-hot encoding.
    """
    labels = onehot.clone()
    for i, val in enumerate(lut):
        labels[onehot == i] = val
    return labels


def pad_volume(img,                         # input cropped volume
               crop_window,                 # crop bounds from full image
               full_shape=(256, 256, 256),  # shape of full padded image
):
    """
    Zero-pad a cropped volume back out to full_shape, placing it at the
    location given by crop_window ([start, stop] per axis).
    """
    widths = []
    for (start, stop), size in zip(crop_window, full_shape):
        widths.append([start, size - stop])
    return np.pad(img, pad_width=widths)

            
def robust_norm(x,                   # input tensor
                m:float=0.,          # minimum intensity value
                M:float=1.,          # maximum intensity value
                min_perc:float=0.,   # minimum % to clip intensities
                max_perc:float=0.95  # maximum % to clip intensities
):
    """
    Clip x to its [min_perc, max_perc] intensity percentiles, then rescale
    the result linearly into the range [m, M].
    """
    n_vox = x.numel()

    # Translate the percentile cutoffs into intensity values
    flat_sorted, _ = x.reshape((n_vox,)).sort()
    lo = flat_sorted[max(int(min_perc * n_vox), 0)]
    hi = flat_sorted[min(int(max_perc * n_vox), n_vox - 1)]

    # Clip, then rescale to [m, M]
    clipped = torch.clamp(x, min=lo, max=hi)
    span = clipped.max() - clipped.min()
    return (M - m) * (clipped - clipped.min()) / span + m



#------------------------------------------------------------------------------
# Small image filename utilities

###
def is_filename(fname):
    """
    Heuristically decide whether a path looks like a filename (vs. a
    directory) by checking for a '.' within its last five characters.
    """
    return '.' in fname[-5:]

###
def is_valid_filename(fname, exts=['.mgh', '.mgz', '.nii', '.nii.gz']):
    """
    Check whether a filename has one of the supported image extensions.
    """
    return any(fname.endswith(ext) for ext in exts)

###
def split_extension(fname:str, exts=['.mgh', '.mgz', '.nii', '.nii.gz']):
    """
    Split a supported image extension off a filename; returns the tuple
    (stem, extension). Exits fatally on an unsupported file type.
    """
    matches = [ext for ext in exts if fname.endswith(ext)]
    if matches:
        ext = matches[0]
        return (fname[:-len(ext)], ext)
    sfs_fatal(f'{fname} is an unsupported image file type.')


#-----------------------------------------------------------------------------#
#                                 Model utils                                 #
#-----------------------------------------------------------------------------#

###
class UNet3D(nn.Module):
    """
    Generic X-dimensional UNet with configurable depth, block widths,
    normalization/activation functions, and optional residual and skip
    connections. Feature widths double at each encoder level.
    """
    def __init__(self,
                 in_channels:int,                    # no. of input channels
                 out_channels:int,                   # no. of output features
                 conv_sz:int=3,                      # conv window size
                 pool_sz:int=2,                      # pooling window size
                 n_convs_per_block:int=2,            # no. of layers per block
                 n_levels:int=4,                     # no. of enc/dec levels
                 n_starting_features:int=24,         # no. of starting features
                 normalization_type:str='Instance',  # normalization func.
                 activation_function:str='ELU',      # activation func.
                 pooling_type:str='MaxPool',         # pooling func.
                 residuals:bool=False,               # flag for residual conns.
                 skip:bool=True,                     # flag for skip conns.
                 X:int=3,                            # no. of spatial dims.
    ):
        super(UNet3D, self).__init__()

        self.conv_sz = conv_sz
        self.pool_sz = pool_sz

        # Feature widths double at each level
        self.block_config = [n_starting_features * (2**i)
                             for i in range(n_levels)]
        self.n_blocks = len(self.block_config)
        self.skip = skip
        self.residuals = residuals

        # Validate the activation name against torch.nn. getattr raises
        # AttributeError for unknown names (as before); BUGFIX: the original
        # built an Exception instance here without ever raising it.
        if not callable(getattr(torch.nn, activation_function)):
            raise ValueError('Invalid activation_function (not an '
                             'attribute of torch.nn)')
        self.activation_function = activation_function

        # NOTE(review): pooling_type is currently unused; MaxPool/MaxUnpool
        # are hard-coded below (unpooling requires max-pool indices)

        # Encoding blocks:
        self.encoding = nn.Sequential()
        encoding_config = [in_channels] + self.block_config

        for b in range(len(encoding_config)-1):
            block = _UNetBlock(n_input_features=encoding_config[b],
                               n_output_features=encoding_config[b+1],
                               n_layers=n_convs_per_block,
                               conv_sz=self.conv_sz,
                               norm_type=normalization_type,
                               activ_fn=activation_function,
                               level=b,
                               residuals=residuals,
                               drop=0,
                               X=X,
            )
            self.encoding.add_module('EncodeBlock%d' % (b+1), block)
            # NOTE(review): b never equals n_levels here, so a pool module
            # is registered after every block (the deepest one is unused by
            # forward()); kept as-is to preserve the module/state_dict layout
            if b != n_levels:
                pool = getattr(nn, 'MaxPool%dd' % X)(kernel_size=pool_sz,
                                                     stride=pool_sz,
                                                     return_indices=True)
                self.encoding.add_module('Pool%d' % (b+1), pool)

        # Decoding blocks:
        self.decoding = nn.Sequential()
        decoding_config = [out_channels] + self.block_config

        for b in reversed(range(len(decoding_config)-1)):
            # NOTE(review): as above, this condition is always true
            if b != n_levels:
                unpool = getattr(nn, 'MaxUnpool%dd' % X)(kernel_size=pool_sz,
                                                         stride=pool_sz)
                self.decoding.add_module('Unpool%d' % (b+1), unpool)

            block = _UNetBlockTranspose(n_input_features=decoding_config[b+1],
                                        n_output_features=decoding_config[b],
                                        n_layers=n_convs_per_block,
                                        conv_sz=conv_sz,
                                        norm_type=normalization_type,
                                        activ_fn=activation_function,
                                        level=b,
                                        residuals=residuals,
                                        drop=0,
                                        X=X,
            )
            self.decoding.add_module('DecodeBlock%d' % (b+1), block)

    ##
    def forward(self, x):
        """Run the encoder/decoder pass with skip connections."""
        enc = [None] * (self.n_blocks + 1)  # encoder outputs per level
        idx = [None] * self.n_blocks        # maxpool indices for unpooling
        siz = [None] * self.n_blocks        # pre-pool sizes for unpooling

        # Encoding path
        enc[0] = x
        for b in range(0, self.n_blocks):
            x = enc[b+1] = \
                self.encoding.__getattr__('EncodeBlock%d' % (b+1))(x)
            siz[b] = x.shape
            if b != self.n_blocks - 1:
                x, idx[b] = self.encoding.__getattr__('Pool%d' % (b+1))(x)

        # Decoding path: unpool, concatenate the matching encoder output
        # (skip connection) on the channel axis, then decode
        for b in reversed(range(self.n_blocks)):
            if b != self.n_blocks - 1:
                x = self.decoding.__getattr__('Unpool%d' % (b+1))\
                    (x, idx[b], output_size=siz[b])
            x = self.decoding.__getattr__('DecodeBlock%d' % (b+1))\
                (torch.cat([x, enc[b+1]], 1))

        return x


#------------------------------------------------------------------------------
# UNet layer classes

###
class _UNetLayer(nn.Module):
    def __init__(self,
                 X:int,
                 n_input_features:int,
                 n_output_features:int,
                 conv_sz:int=3,
                 norm_type:str=None,
                 activ_fn:str=None,
                 drop_rate:float=0.,
                 **kwargs
    ):
        """
        Encoding layer (called by _UNetBlock): norm -> activation -> conv
        -> dropout, where each stage falls back to Identity when disabled.
        """
        super(_UNetLayer, self).__init__()
        pad_sz = (conv_sz-1)//2 if conv_sz % 2 == 1 else (conv_sz//2)
        # A normalization layer makes the conv bias redundant
        conv_bias = norm_type is None

        # BUGFIX: only build the norm when norm_type is given; the original
        # evaluated 'nn.NoneNorm%dd' unconditionally and crashed for
        # norm_type=None despite handling None just below. eval() lookups
        # replaced with getattr on nn throughout.
        norm = getattr(nn, '%sNorm%dd' % (norm_type, X))(n_input_features) \
            if norm_type is not None else nn.Identity()
        activ = getattr(nn, activ_fn)() if activ_fn is not None \
            else nn.Identity()
        conv = getattr(nn, 'Conv%dd' % X)(n_input_features,
                                          n_output_features,
                                          kernel_size=conv_sz,
                                          padding=pad_sz,
                                          bias=conv_bias)
        drop = getattr(nn, 'Dropout%dd' % X)(p=drop_rate) if drop_rate > 0 \
            else nn.Identity()

        self.add_module('norm', norm)
        self.add_module('activ', activ)
        self.add_module('conv', conv)
        self.add_module('drop', drop)

    def forward(self, x):
        return self.drop(self.conv(self.activ(self.norm(x))))


###
class _UNetLayerTranspose(nn.Module):
    def __init__(self,
                 X:int,
                 n_input_features:int,
                 n_output_features:int,
                 conv_sz:int=3,
                 norm_type:str=None,
                 activ_fn:str=None,
                 drop_rate:float=0.,
                 **kwargs
    ):
        """
        Decoding layer (called by _UNetBlockTranspose): dropout -> norm ->
        activation -> transposed conv, with Identity fallbacks when a stage
        is disabled.
        """
        super(_UNetLayerTranspose, self).__init__()
        pad_sz = (conv_sz-1)//2 if conv_sz % 2 == 1 else (conv_sz//2)
        # A normalization layer makes the conv bias redundant
        conv_bias = norm_type is None

        drop = getattr(nn, 'Dropout%dd' % X)(p=drop_rate) if drop_rate > 0 \
            else nn.Identity()
        # BUGFIX: only build the norm when norm_type is given; the original
        # evaluated 'nn.NoneNorm%dd' unconditionally and crashed for
        # norm_type=None. eval() lookups replaced with getattr on nn.
        norm = getattr(nn, '%sNorm%dd' % (norm_type, X))(n_input_features) \
            if norm_type is not None else nn.Identity()
        activ = getattr(nn, activ_fn)() if activ_fn is not None \
            else nn.Identity()
        conv = getattr(nn, 'ConvTranspose%dd' % X)(n_input_features,
                                                   n_output_features,
                                                   kernel_size=conv_sz,
                                                   padding=pad_sz,
                                                   bias=conv_bias)

        self.add_module('drop', drop)
        self.add_module('norm', norm)
        self.add_module('activ', activ)
        self.add_module('conv', conv)

    def forward(self, x):
        return self.conv(self.activ(self.norm(self.drop(x))))


#------------------------------------------------------------------------------
# UNet block classes

###
class _UNetBlock(nn.ModuleDict):
    def __init__(self,
                 X:int,
                 n_input_features:int,
                 n_output_features:int,
                 n_layers:int,
                 conv_sz:int,
                 norm_type:str,
                 activ_fn:str,
                 level:int,
                 residuals:bool=False,
                 drop=0,
                 skip=False,
                 **kwargs
    ):
        """
        Encoding UNet block: n_layers stacked _UNetLayers. When skip=True,
        the final layer doubles its output features so the decoder can
        concatenate a skip connection.
        """
        super(_UNetBlock, self).__init__()
        # Residual adds require shape-preserving (odd) conv kernels
        self.residuals = False if conv_sz % 2 == 0 else residuals

        for i in range(n_layers):
            n_in = n_input_features if i==0 else n_output_features
            # Double the width of the last layer when feeding a skip conn.
            n_out = (1 + (skip and i==(n_layers-1))) * n_output_features
            # BUGFIX: pass the dropout rate as drop_rate; the original passed
            # 'drop=', which fell into _UNetLayer's **kwargs and was ignored
            layer = _UNetLayer(n_input_features=n_in,
                               n_output_features=n_out,
                               conv_sz=conv_sz,
                               norm_type=norm_type,
                               activ_fn=activ_fn,
                               drop_rate=drop,
                               X=X,
            )
            self.add_module('ConvLayer%d' % (i+1), layer)

    def forward(self, x):
        # Apply layers in order; add residuals on all but the first layer
        # (layer 1 may change the channel count)
        for name, layer in self.items():
            res = x
            x = layer(x)
            if self.residuals and name[-1]!='1':  x += res
        return x

    
###
class _UNetBlockTranspose(nn.ModuleDict):
    def __init__(self,
                 X:int,
                 n_input_features:int,
                 n_output_features:int,
                 n_layers:int,
                 conv_sz:int,
                 norm_type:str,
                 activ_fn:str,
                 level:int,
                 residuals:bool=False,
                 drop=0,
                 skip=True,
                 **kwargs
    ):
        """
        Decoding UNet block: n_layers stacked _UNetLayerTranspose layers.
        When skip=True, the first layer accepts double the input features to
        absorb the concatenated skip connection.
        """
        super(_UNetBlockTranspose, self).__init__()
        self.printout = False
        # NOTE(review): this parity condition is inverted relative to
        # _UNetBlock (residuals enabled only for even conv_sz here); kept
        # as-is to preserve behavior — confirm intent
        self.residuals = residuals if conv_sz % 2 == 0 else False

        for i in range(n_layers):
            # First layer absorbs the concatenated skip connection
            n_in = (1 + (skip and i==0)) * n_input_features
            n_out = n_output_features if i==(n_layers-1) else n_input_features
            # BUGFIX: pass the dropout rate as drop_rate; the original passed
            # 'drop=', which fell into **kwargs and was ignored
            layer = _UNetLayerTranspose(n_input_features=n_in,
                                        n_output_features=n_out,
                                        conv_sz=conv_sz,
                                        norm_type=norm_type,
                                        activ_fn=activ_fn,
                                        drop_rate=drop,
                                        X=X,
            )
            self.add_module('ConvLayer%d' % (i+1), layer)

    def forward(self, x):
        # Apply layers in order; add residuals on all but the first layer
        for name, layer in self.items():
            res = x
            x = layer(x)
            if self.residuals and name[-1]!='1':  x += res
        return x


    
#-----------------------------------------------------------------------------#
#                                     End                                     #
#-----------------------------------------------------------------------------#

###
# Script entry point: run main() only when executed directly
if __name__ == "__main__":
    main()
