#!/usr/bin/python3

import os
import sys
import platform
import csv
import glob
import time
import argparse
import traceback
import surfa as sf
import numpy as np
import nibabel as nib
from datetime import timedelta
from scipy.ndimage import label as scipy_label
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import binary_dilation, distance_transform_edt, gaussian_filter

# set tensorflow logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend
backend.set_image_data_format('channels_last')


# ================================================================================================
#                                         Main Entrypoint
# ================================================================================================


def main() -> None:
    """Command-line entry point.

    Parses arguments, configures TensorFlow threading/devices, locates the
    trained model under $FREESURFER_HOME, and launches the prediction pipeline.
    Two mutually exclusive modes are supported:
      - FS mode (--s): run on recon-all processed subjects in $SUBJECTS_DIR.
      - T1 mode (--i/--o): run on arbitrary T1-weighted scan(s) of ~1mm resolution.
    """

    # parse arguments
    parser = argparse.ArgumentParser(description="This module can be run in two modes: a) on FreeSurfer subjects, and "
                                                 "b) on any T1-weighted scan(s) of approximatively 1mm resolution. ",
                                     epilog='\n')

    # FreeSurfer mode
    parser.add_argument("--s", nargs='*',
                        help="(required in FS mode) Name of one or several subjects in $SUBJECTS_DIR on which to run "
                             "mri_segment_hypothalamic_subunits, "
                             "assuming recon-all has been run on the specified subjects. "
                             "The output segmentations will automatically be saved in each subject's mri folder. "
                             "If no argument is given, mri_segment_hypothalamic_subunits will run on all the subjects "
                             "in $SUBJECTS_DIR.")
    parser.add_argument("--sd", help="(FS mode, optional) override current $SUBJECTS_DIR")
    parser.add_argument("--write_posteriors", action="store_true", help="(FS mode, optional) save posteriors, "
                                                                        "default is False")

    # normal mode
    parser.add_argument("--i", help="(required in T1 mode) Image(s) to segment. "
                                    "Can be a path to a single image or to a folder.")
    parser.add_argument("--o", help="(required in T1 mode) Segmentation output(s). "
                                    "Must be a folder if --i designates a folder.")
    parser.add_argument("--post", help="(T1 mode, optional) Posteriors output(s). "
                                       "Must be a folder if --i designates a folder.")
    parser.add_argument("--resample", help="(T1 mode, optional) Resampled image(s). "
                                           "Must be a folder if --i designates a folder.")
    parser.add_argument("--vol", help="(T1 mode, optional) "
                                      "Output CSV file with volumes for all structures and subjects.")

    # in both cases
    parser.add_argument("--crop", nargs='+', type=int, default=None, dest="crop",
                        help="(both modes, optional) Size of the central patch to analyse (must be divisible by 8). "
                             "The whole image is analysed by default.")
    parser.add_argument("--threads", type=int, default=1, help="(both modes, optional) Number of cores to be used. "
                                                               "Default uses 1 core.")
    parser.add_argument("--cpu", action="store_true", help="(both modes, optional) enforce running with CPU rather "
                                                           "than GPU.")

    # check for no arguments: print the usage instead of argparse's terse error
    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit(1)

    # parse commandline
    args = parser.parse_args()

    # enforce CPU processing if necessary (hiding devices must happen before TF grabs a GPU)
    if args.cpu:
        print('using CPU, hiding all CUDA_VISIBLE_DEVICES')
        os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    # limit the number of threads to be used if running on CPU
    if args.threads == 1:
        print('using 1 thread')
    else:
        print('using %s threads' % args.threads)
    tf.config.threading.set_inter_op_parallelism_threads(args.threads)
    tf.config.threading.set_intra_op_parallelism_threads(args.threads)

    # check that freesurfer has been sourced
    if not os.environ.get('FREESURFER_HOME'):
        sf.system.fatal('FREESURFER_HOME is not set. Please source freesurfer.')
    synthseg_home = os.environ.get('FREESURFER_HOME')

    # path model: trained weights distributed under $FREESURFER_HOME/models
    model = os.path.join(synthseg_home, 'models', 'hypothalamic_subunits.h5')

    # run prediction (args.crop is None unless --crop was given; predict handles both)
    predict(
        name_subjects=args.s,
        path_subjects_dir=args.sd,
        write_posteriors_FS=args.write_posteriors,
        path_images=args.i,
        path_segmentations=args.o,
        path_model=model,
        path_posteriors=args.post,
        path_resampled=args.resample,
        path_volumes=args.vol,
        crop=args.crop,
    )


# ================================================================================================
#                                 Prediction and Processing Utilities
# ================================================================================================


def predict(name_subjects=None,
            path_subjects_dir=None,
            write_posteriors_FS=False,
            path_images=None,
            path_segmentations=None,
            path_model='../data/model.h5',
            path_posteriors=None,
            path_resampled=None,
            path_volumes=None,
            crop=184):
    '''
    Prediction pipeline: prepares all I/O paths, builds the network, then loops over
    the input images (preprocess -> network prediction -> postprocess -> save outputs).

    :param name_subjects: (FS mode) list of FreeSurfer subject names, or None in T1 mode
    :param path_subjects_dir: (FS mode) optional override of $SUBJECTS_DIR
    :param write_posteriors_FS: (FS mode) whether to also save posteriors
    :param path_images: (T1 mode) image file or folder to segment, or None in FS mode
    :param path_segmentations: (T1 mode) segmentation output file/folder
    :param path_model: path to the trained keras model weights (.h5)
    :param path_posteriors: (T1 mode) optional posteriors output file/folder
    :param path_resampled: (T1 mode) optional output for the resampled input images
    :param path_volumes: (T1 mode) optional CSV output regrouping volumes of all inputs
    :param crop: size of the central patch to analyse; note that main() passes None
        (whole image) unless --crop was given, which overrides this default of 184
    Exits with status 1 if any input image failed to process.
    '''

    if path_model is None:
        sf.system.fatal("A model file is necessary")

    # prepare input/output filepaths (all per-image outputs are lists, one entry per image)
    outputs = prepare_output_files(name_subjects, path_subjects_dir, write_posteriors_FS, path_images,
                                   path_segmentations, path_posteriors, path_resampled, path_volumes)
    path_images = outputs[0]
    path_segmentations = outputs[1]
    path_posteriors = outputs[2]
    path_resampled = outputs[3]
    path_volumes = outputs[4]
    path_main_volumes = outputs[5]
    path_stats = outputs[6]

    # get label list: index 0 is background, followed by the subunit label values 801-810
    labels_segmentation = np.concatenate([np.zeros(1, dtype='int32'), np.arange(801, 811, dtype='int32')])

    # prepare volume file if needed (header only; rows are appended per image below)
    if path_main_volumes is not None:
        write_csv_file(volumes=None, filename=path_main_volumes, subject=None, write_header=True, open_type='w')

    # build network (input shape is inferred from the first image)
    _, _, n_dims, n_channels, _, _ = get_volume_info(path_images[0])
    model_input_shape = [None] * n_dims + [n_channels]
    net = build_model(path_model, model_input_shape, len(labels_segmentation))

    # perform segmentation; progress is printed for every image if few, else every 10th
    if len(path_images) <= 10:
        loop_info = LoopInfo(len(path_images), 1, 'predicting', True)
    else:
        loop_info = LoopInfo(len(path_images), 10, 'predicting', True)
    list_errors = list()
    for i in range(len(path_images)):
        loop_info.update(i)

        try:

            # preprocessing
            image, aff, h, im_res, shape, crop_idx = preprocess(path_images[i], crop, path_resample=path_resampled[i])

            # prediction
            prediction_patch = net.predict(image)

            # postprocessing
            seg, posteriors = postprocess(prediction_patch, shape, crop_idx, labels_segmentation, aff)

            # write predictions to disc
            save_volume(seg, aff, h, path_segmentations[i], dtype='int32')
            if path_posteriors[i] is not None:
                save_volume(posteriors, aff, h, path_posteriors[i], dtype='float32')

            # write volumes to disc if necessary
            if (path_main_volumes is not None) | (path_volumes[i] is not None):  # compute volumes only if necessary
                # expected volume per structure = sum of its posteriors times the voxel volume
                volumes = np.sum(posteriors[..., 1:], axis=tuple(range(0, len(posteriors.shape) - 1)))
                volumes = np.around(volumes * np.prod(im_res), 3)
                # first half of the labels is one hemisphere, second half the other
                # (see the column order in write_csv_file's header)
                volumes_whole = np.array([np.sum(volumes[:int(len(volumes) / 2)]),
                                          np.sum(volumes[int(len(volumes) / 2):])])
                volumes = np.concatenate([volumes, volumes_whole])
                if name_subjects is None:  # any T1 mode
                    subject_name = os.path.basename(path_images[i]).replace('.nii.gz', '')
                else:  # FS mode
                    subject_name = os.path.basename(os.path.dirname(os.path.dirname(path_images[i])))
                if path_main_volumes is not None:  # append volumes to main file (regrouping all subjects)
                    write_csv_file(volumes, path_main_volumes, subject_name, write_header=False, open_type='a')
                if path_volumes[i] is not None:  # create individual volume file in each subject subdirectory (FS mode)
                    write_csv_file(volumes, path_volumes[i], subject_name, write_header=True, open_type='w')
                if path_stats[i] is not None:  # create individual stats file in each subject subdirectory (FS mode)
                    write_fs_stats_file(volumes, path_stats[i])

        except Exception as e:
            # best-effort batch processing: record the failure, report it, and move on
            list_errors.append(path_images[i])
            print('\nthe following problem occured with image %s :' % path_images[i])
            print(traceback.format_exc())
            print('resuming program execution\n')
            continue

    # print output info
    if (len(path_segmentations) == 1) & (len(list_errors) == 0):  # only one image is processed with no error
        print('\nsegmentation  saved in:    ' + path_segmentations[0])
        if path_posteriors[0] is not None:
            print('posteriors saved in:       ' + path_posteriors[0])
        if path_volumes[0] is not None:  # for FS subject
            print('volumes saved in:          ' + path_volumes[0])
        if path_resampled[0] is not None:
            print('resampled images saved in: ' + path_resampled[0])
        if path_main_volumes is not None:  # for single image
            print('volumes saved in:          ' + path_main_volumes)
    elif (len(path_segmentations) > 1) & (len(list_errors) < len(path_segmentations)):  # at least 1 image with no error
        if name_subjects is None:  # images in folder
            print('\nsegmentations saved in:    ' + os.path.dirname(path_segmentations[0]))
            if path_posteriors[0] is not None:
                print('posteriors saved in:       ' + os.path.dirname(path_posteriors[0]))
            if path_resampled[0] is not None:
                print('resampled images saved in: ' + os.path.dirname(path_resampled[0]))
            if path_main_volumes is not None:
                print('volumes saved in:          ' + path_main_volumes)
        else:  # several subjects
            if path_posteriors[0] is not None:
                print('\nsegmentations, posteriors, resampled images, and individual subject volumes '
                      'saved in each subject directory')
            else:
                print('\nsegmentations, resampled images, and individual subject volumes '
                      'saved in each subject directory')
            print('additional file regrouping the volumes of all subjects saved in: ' + path_main_volumes)

    print('\nIf you use this tool in a publication, please cite:')
    print('Automated segmentation of the hypothalamus and associated subunits in brain MRI')
    print('B. Billot, M. Bocchetta, E. Todd, A. V. Dalca, J. D. Rohrer, J. E. Iglesias')
    print('NeuroImage 2020')

    if len(list_errors) > 0:
        print('\nERROR: some problems occured for the following inputs (see corresponding errors above):')
        for path_error_image in list_errors:
            print(path_error_image)
        sys.exit(1)


def prepare_output_files(name_subjects, subjects_dir, write_posteriors_FS, path_images, out_seg, out_posteriors,
                         out_resampled, main_volumes):
    '''
    Build and validate all input/output paths for both running modes.

    In T1 mode (path_images given), the input may be a single image or a folder, and
    the output arguments may correspondingly be files or folders. In FS mode
    (name_subjects given), all paths are derived from each subject's directory
    in $SUBJECTS_DIR.

    :param name_subjects: list of FreeSurfer subject names, or None in T1 mode
    :param subjects_dir: optional override of $SUBJECTS_DIR (FS mode only)
    :param write_posteriors_FS: whether posteriors are saved in FS mode
    :param path_images: image file or folder of images (T1 mode), or None in FS mode
    :param out_seg: segmentation output file/folder (T1 mode)
    :param out_posteriors: optional posteriors output file/folder (T1 mode)
    :param out_resampled: optional resampled-image output file/folder (T1 mode)
    :param main_volumes: optional CSV path regrouping volumes of all inputs (T1 mode)
    :return: tuple (path_images, out_seg, out_posteriors, out_resampled, out_volumes,
        main_volumes, out_stats) where all entries except main_volumes are lists with
        one element per input image (elements are None when that output is disabled).
    '''

    # T1 mode
    if path_images is not None:

        # check other inputs
        if out_seg is None:
            sf.system.fatal('please specify an output file/folder (--o) when using flag --i')
        if name_subjects is not None:
            sf.system.fatal('please choose between flags --i and --s, they cannot be used at the same time')
        if subjects_dir is not None:
            print('WARNING: $SUBJECT_DIR not used when flags --i and --o are specified, ignoring value of flag --sd')
        if write_posteriors_FS:
            print('WARNING: flag --write_posteriors not used whith flag --i, ignoring flag --write_posteriors.')
            print('WARNING: If you wish to write the posteriors in the T1 mode, '
                  'please use --post flag instead.')

        # convert paths to absolute paths
        path_images = os.path.abspath(path_images)
        basename = os.path.basename(path_images)
        out_seg = os.path.abspath(out_seg)
        out_posteriors = os.path.abspath(out_posteriors) if (out_posteriors is not None) else None
        out_resampled = os.path.abspath(out_resampled) if (out_resampled is not None) else None
        main_volumes = os.path.abspath(main_volumes) if (main_volumes is not None) else None

        # path_images is a folder (no recognised image extension in its basename)
        if not any(ext in basename for ext in ('.nii.gz', '.nii', '.mgz', '.npz')):

            # input images
            if os.path.isfile(path_images):
                sf.system.fatal('Extension not supported for %s, only use: .nii.gz, .nii, .mgz, or .npz' % path_images)
            path_images = list_images_in_folder(path_images)

            # one output path per input image for each requested output type
            out_seg = _folder_output_paths(out_seg, path_images, '_hypo_seg')
            out_posteriors = _folder_output_paths(out_posteriors, path_images, '_posteriors') \
                if out_posteriors is not None else [None] * len(path_images)
            out_resampled = _folder_output_paths(out_resampled, path_images, '_resampled') \
                if out_resampled is not None else [None] * len(path_images)

        # path_images is an image
        else:

            # input images
            if not os.path.isfile(path_images):
                sf.system.fatal("file does not exist: %s \n"
                                "please make sure the path and the extension are correct" % path_images)
            path_images = [path_images]

            # resolve each output argument (given either as a file or as a folder)
            out_seg = [_single_output_path(out_seg, path_images[0], '_hypo_seg')]
            out_posteriors = [_single_output_path(out_posteriors, path_images[0], '_posteriors')
                              if out_posteriors is not None else None]
            out_resampled = [_single_output_path(out_resampled, path_images[0], '_resampled')
                             if out_resampled is not None else None]

        # main volume file
        if main_volumes is not None:
            if main_volumes[-4:] != '.csv':
                print('Path for volume outputs provided without csv extension. Adding csv extension.')
                main_volumes += '.csv'
            # bug fix: the parent folder must be created whether or not the extension was added
            # (previously mkdir only ran when '.csv' was missing, so 'out/vol.csv' could fail later)
            mkdir(os.path.dirname(main_volumes))
        out_volumes = [None] * len(path_images)
        out_stats = [None] * len(path_images)

    # FS mode, run on either a subset of subjects, or on all subjects, with an option to override $SUBJECTS_DIR
    elif name_subjects is not None:

        # warn about flags that are ignored in FS mode
        if out_seg is not None:
            print('WARNING: in FS mode segmentations are automatically saved in each subject directory, '
                  'ignoring value provided in --o')
        if main_volumes is not None:
            print('WARNING: in FS mode volumes are automatically saved in each subject directory, '
                  'ignoring value provided in --vol')
        if write_posteriors_FS:
            if out_posteriors is not None:
                print('WARNING: posteriors will automatically be saved in specified subject directory, '
                      'ignoring value provided in --post')
            else:
                print("Posteriors will be saved in each subject's directory.")
        elif out_posteriors is not None:  # write_posteriors is False and out_posteriors was specified
            print('WARNING: flag --post not used in FS mode, ignoring flag --post.')
            print('WARNING: If you wish to write the posteriors in the FS mode, '
                  'please append --write_posteriors to your command line.')

        # override SUBJECTS_DIR if necessary
        if subjects_dir is not None:
            subjects_dir = os.path.abspath(subjects_dir)
            if not os.path.isdir(subjects_dir):
                sf.system.fatal("Could not find " + subjects_dir)
            os.environ['SUBJECTS_DIR'] = subjects_dir
        else:
            # robustness: fail with a clear message instead of a raw KeyError when unset
            subjects_dir = os.environ.get('SUBJECTS_DIR')
            if subjects_dir is None:
                sf.system.fatal('SUBJECTS_DIR is not set. Please source freesurfer or use flag --sd.')

        # list subjects in SUBJECTS_DIR if none were given
        if not name_subjects:
            name_subjects = list_subfolders(subjects_dir, whole_path=False)

        path_images = list()
        out_seg = list()
        out_posteriors = list()
        out_resampled = list()
        out_volumes = list()
        out_stats = list()
        for name in name_subjects:

            # check that the provided subject dir exists
            subject_dir = os.path.join(subjects_dir, name)
            if not os.path.isdir(subject_dir):
                sf.system.fatal("Could not find subject dir " + subject_dir)
            path_image = os.path.join(subject_dir, 'mri', 'nu.mgz')

            # build paths if input image exists
            if os.path.isfile(path_image):
                path_images.append(path_image)
                out_seg.append(os.path.join(subject_dir, 'mri', 'hypothalamic_subunits_seg.v1.mgz'))
                out_resampled.append(os.path.join(subject_dir, 'mri', 'hypothalamic_subunits_nu_resampled_1mm.v1.mgz'))
                out_volumes.append(os.path.join(subject_dir, 'mri', 'hypothalamic_subunits_volumes.v1.csv'))
                out_stats.append(os.path.join(subject_dir, 'stats', 'hypothalamic_subunits_volumes.v1.stats'))
                if write_posteriors_FS:
                    out_posteriors.append(os.path.join(subject_dir, 'mri', 'hypothalamic_subunits_posteriors.v1.mgz'))
                else:
                    out_posteriors.append(None)
            else:
                print('WARNING: no such file: ' + path_image + ', continuing program execution without this file')

        # check that we have at least one valid image
        if len(path_images) == 0:
            sf.system.fatal("Could not find any image to segment")

        # if several subjects are run, all volumes are regrouped in a single file
        if len(path_images) > 1:
            main_volumes = os.path.join(subjects_dir, 'hypothalamic_subunits_volumes_all.v1.csv')
        else:
            main_volumes = None

    else:
        sf.system.fatal('please provide an input image/directory (--i), '
                        'or a specific FreeSurfer subject directory (--s), '
                        'or the FreeSurfer $SUBJECT_DIR (--sd)')
        path_images = out_volumes = out_stats = None  # unreachable if fatal exits; kept for safety

    return path_images, out_seg, out_posteriors, out_resampled, out_volumes, main_volumes, out_stats


def _rename_with_suffix(filename, suffix):
    """Insert suffix before the image extension (e.g. 'im.nii.gz' -> 'im_posteriors.nii.gz')."""
    # '.nii' also matches the '.nii' part of '.nii.gz', so all four extensions are covered
    for ext in ('.nii', '.mgz', '.npz'):
        filename = filename.replace(ext, suffix + ext)
    return filename


def _folder_output_paths(out_dir, path_images, suffix):
    """Validate a folder-type output argument, create the folder, and return one path per image."""
    if out_dir.endswith(('.nii.gz', '.nii', '.mgz', '.npz')):
        sf.system.fatal('Output folders cannot have extensions: .nii.gz, .nii, .mgz, or .npz, had %s' % out_dir)
    mkdir(out_dir)
    # suffixing is applied to the basename only, so folder names containing '.nii' etc. are untouched
    return [os.path.join(out_dir, _rename_with_suffix(os.path.basename(image), suffix)) for image in path_images]


def _single_output_path(out_path, path_image, suffix):
    """Resolve a file-or-folder output argument for a single input image, creating needed folders."""
    if any(ext in out_path for ext in ('.nii.gz', '.nii', '.mgz', '.npz')):
        # out_path designates a file: only ensure its parent folder exists
        mkdir(os.path.dirname(out_path))
        return out_path
    # out_path designates a folder: derive the output filename from the input image
    mkdir(out_path)
    return os.path.join(out_path, _rename_with_suffix(os.path.basename(path_image), suffix))


def preprocess(path_image, crop=None, n_levels=3, path_resample=None):
    """Load an image and prepare it for the network.

    Steps: load -> force 3D single-channel -> resample to 1mm if outside tolerance ->
    align to an identity affine -> crop to a shape divisible by 2**n_levels ->
    rescale intensities to [0, 1] -> add batch and channel axes.

    :param path_image: path of the image to preprocess
    :param crop: size of the central patch to keep, or None for the whole image
    :param n_levels: number of resolution levels of the network (the crop shape
        must be divisible by 2**n_levels)
    :param path_resample: optional path where the resampled image is saved
    :return: (image, affine, header, resolution, aligned shape, crop indices)
    """

    # load the image along with its geometry
    im, shape, aff, n_dims, n_channels, h, im_res = get_volume_info(path_image, True)

    # enforce a 3D, single-channel volume
    if n_dims == 4 and n_channels == 1:
        n_dims = 3
        im = im[..., 0]
    elif n_dims != 3:
        sf.system.fatal('input should have 3 dimensions, had %s' % n_dims)
    elif n_channels > 1:
        print('WARNING: detected more than 1 channel, only keeping the first channel.')
        im = im[..., 0]

    # resample to 1mm isotropic when the resolution falls outside the [0.95, 1.15] tolerance
    outside_tolerance = (im_res > np.array([1.15] * 3)) | (im_res < np.array([0.95] * 3))
    if np.any(outside_tolerance):
        im_res = np.array([1.] * 3)
        im, aff = resample_volume(im, aff, im_res)
        if path_resample is not None:
            save_volume(im, aff, h, path_resample)

    # reorient the volume to a canonical (identity) affine
    im = align_volume_to_ref(im, aff, aff_ref=np.eye(4), n_dims=n_dims, return_copy=False)
    shape = list(im.shape[:n_dims])

    # derive a cropping shape that fits inside the image and is divisible by 2**n_levels
    if crop is None:
        crop = shape.copy()
    crop = reformat_to_list(crop, length=n_dims, dtype='int')
    crop = [min(dim, target) for dim, target in zip(shape, crop)]
    if any(size % (2 ** n_levels) for size in crop):
        crop = [find_closest_number_divisible_by_m(size, 2 ** n_levels) for size in crop]

    # crop, normalise intensities, and add batch/channel axes
    im, crop_idx = crop_volume(im, cropping_shape=crop, return_crop_idx=True)
    im = rescale_volume(im, new_min=0, new_max=1, min_percentile=0.5, max_percentile=99.5)
    im = add_axis(im, axis=[0, -1])

    return im, aff, h, im_res, shape, crop_idx


def build_model(model_file, input_shape, n_lab):
    """Instantiate the segmentation UNet and load its trained weights.

    :param model_file: path to the .h5 weight file
    :param input_shape: network input shape, e.g. [None, None, None, 1]
    :param n_lab: number of output labels (background included)
    :return: a keras model ready for inference
    """

    # architecture must match training: 3 levels, 24 base features doubled at each level,
    # 2 convolutions per level, batch normalisation on the last axis
    architecture = dict(nb_features=24,
                        nb_levels=3,
                        conv_size=3,
                        feat_mult=2,
                        nb_conv_per_level=2,
                        batch_norm=-1)
    model = unet(input_shape=input_shape, nb_labels=n_lab, **architecture)
    model.load_weights(model_file, by_name=True)

    return model


def postprocess(prediction, im_shape, crop_idx, labels, aff):
    """Convert the raw network output into a final label map and posteriors.

    :param prediction: network output for one image; squeezed to (patch_shape, n_labels)
    :param im_shape: spatial shape of the whole (aligned) image
    :param crop_idx: indices of the analysed patch within the whole image
        (first half = lower bounds, second half = upper bounds)
    :param labels: 1D array of output label values, with background at index 0
    :param aff: affine of the input image, used to restore its original orientation
    :return: (seg, posteriors) pasted back to full image size and realigned
    """

    # get posteriors and segmentation
    post_patch = np.squeeze(prediction)
    seg_patch = post_patch.argmax(-1)

    # further crop the seg_patch around predicted values
    seg_patch_cropped, patch_crop = crop_volume_around_region(seg_patch, threshold=0, margin=2)
    post_patch_cropped = crop_volume_with_idx(post_patch, patch_crop, n_dims=3)

    # keep biggest connected component, separately for each side
    # (class indices 1-5 are the left subunits, 6-10 the right ones)
    left_mask = get_largest_connected_component((seg_patch_cropped > 0) & (seg_patch_cropped < 6))
    right_mask = get_largest_connected_component(seg_patch_cropped > 5)
    seg_patch_cropped *= (left_mask | right_mask)

    # mask posteriors of each label: zero them outside a dilation of that label's segmentation
    # (build_binary_structure(1, 3) presumably builds a 3D structuring element of radius 1 — TODO confirm)
    dilate_struct = build_binary_structure(1, 3)
    for i in range(1, len(labels)):
        tmp_post_patch_cropped = post_patch_cropped[..., i]
        tmp_mask = binary_dilation(seg_patch_cropped == i, dilate_struct)
        tmp_post_patch_cropped[np.logical_not(tmp_mask)] = 0
        post_patch_cropped[..., i] = tmp_post_patch_cropped
    post_patch_cropped[..., 0] = 1 - np.sum(post_patch_cropped[..., 1:], axis=-1)  # renormalise

    # paste patches back to matrices of original image size
    seg = np.zeros(shape=im_shape, dtype='int32')
    posteriors = np.zeros(shape=[*im_shape, labels.shape[0]])
    posteriors[..., 0] = np.ones(im_shape)  # set background posteriors to 1
    # offset the patch-level crop by the lower bounds of the initial central crop
    crop_idx = patch_crop + np.tile(crop_idx[:int(np.array(crop_idx).shape[0] / 2)], 2)
    seg[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5]] = seg_patch_cropped
    posteriors[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], :] = post_patch_cropped
    # map class indices back to the actual label values (801-810)
    seg = labels[seg.astype('int32')].astype('int32')

    # align prediction back to first orientation
    seg = align_volume_to_ref(seg, aff=np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)
    posteriors = align_volume_to_ref(posteriors, np.eye(4), aff_ref=aff, n_dims=3, return_copy=False)

    return seg, posteriors


def write_csv_file(volumes, filename, subject, write_header, open_type):
    """Writes volumetric stats to csv file.
    :param volumes: iterable with volumes for all structures specified in below header,
        or None to write only the header
    :param filename: path of csv file where volumes will be saved
    :param subject: subject name written in the first column (unused if volumes is None)
    :param write_header: whether to write the header or not
    :param open_type: can be 'w' (write new file), or 'a' (add volume to an already existing file)"""

    header = ['subject',
              'left anterior-inferior',
              'left anterior-superior',
              'left posterior',
              'left tubular inferior',
              'left tubular superior',
              'right anterior-inferior',
              'right anterior-superior',
              'right posterior',
              'right tubular inferior',
              'right tubular superior',
              'whole left',
              'whole right']
    # newline='' is required by the csv module to avoid spurious blank lines on Windows;
    # the with-statement closes the file, so no explicit close() is needed
    with open(filename, open_type, newline='') as csv_file:
        writer = csv.writer(csv_file)
        if write_header:
            writer.writerow(header)
        if volumes is not None:
            datarow = [subject] + ['%.3f' % vol for vol in volumes]
            writer.writerow(datarow)


def write_fs_stats_file(volumes, filename):
    """Write volumetric stats to FS stats file.
    :param volumes: iterable of 12 volumes ordered as in the segnames list below
    :param filename: output path of the .stats file
    Each data line contains: index, index, 0, right-justified volume, structure name."""

    # bug fix: 'Right-Anterior-Ssuperior' was misspelled ('Superior')
    segnames = ['Left-Anterior-Inferior',
                'Left-Anterior-Superior',
                'Left-Posterior',
                'Left-Tubular-Inferior',
                'Left-Tubular-Superior',
                'Right-Anterior-Inferior',
                'Right-Anterior-Superior',
                'Right-Posterior',
                'Right-Tubular-Inferior',
                'Right-Tubular-Superior',
                'Whole-Left',
                'Whole-Right']
    # format all volumes first so they can be right-justified to a common width
    formatted = ['%.3f' % vol for vol in volumes]
    volwidth = len(max(formatted, key=len))
    with open(filename, 'w') as f:
        f.write('# Hypothalamic Subunit Volumetric Stats\n')
        for i, vol in enumerate(formatted):
            f.write('%s  %s  0  %s  %s\n' % (
                str(i + 1).rjust(2),
                str(i + 1).rjust(2),
                vol.rjust(volwidth),
                segnames[i],
            ))


# ================================================================================================
#                       Neurite Utilities - See github.com/adalca/neurite
# ================================================================================================


def unet(nb_features,
         input_shape,
         nb_levels,
         conv_size,
         nb_labels,
         name='unet',
         prefix=None,
         feat_mult=1,
         pool_size=2,
         padding='same',
         dilation_rate_mult=1,
         activation='elu',
         skip_n_concatenations=0,
         use_residuals=False,
         final_pred_activation='softmax',
         nb_conv_per_level=1,
         layer_nb_feats=None,
         conv_dropout=0,
         batch_norm=None,
         input_model=None):
    """
    Unet-style keras model with an overdose of parametrization, assembled from a
    convolutional encoder and decoder. Copied with permission from
    github.com/adalca/neurite.
    """

    # default the layer-name prefix to the model name
    if prefix is None:
        prefix = name

    # normalise the pooling size to one entry per spatial dimension
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # options shared between the encoding and decoding arms
    shared_options = dict(name=name,
                          prefix=prefix,
                          feat_mult=feat_mult,
                          pool_size=pool_size,
                          padding=padding,
                          dilation_rate_mult=dilation_rate_mult,
                          activation=activation,
                          use_residuals=use_residuals,
                          nb_conv_per_level=nb_conv_per_level,
                          conv_dropout=conv_dropout,
                          batch_norm=batch_norm)

    # encoding (downsampling) arm
    encoder = conv_enc(nb_features,
                       input_shape,
                       nb_levels,
                       conv_size,
                       layer_nb_feats=layer_nb_feats,
                       input_model=input_model,
                       **shared_options)

    # decoding (upsampling) arm; the per-layer feature list is advanced past the
    # entries already consumed by the encoder. use_skip_connections=True makes it a u-net.
    remaining_feats = layer_nb_feats[(nb_levels * nb_conv_per_level):] if layer_nb_feats is not None else None
    return conv_dec(nb_features,
                    None,
                    nb_levels,
                    conv_size,
                    nb_labels,
                    use_skip_connections=True,
                    skip_n_concatenations=skip_n_concatenations,
                    final_pred_activation=final_pred_activation,
                    layer_nb_feats=remaining_feats,
                    input_model=encoder,
                    **shared_options)


def conv_enc(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             dilation_rate_mult=1,
             padding='same',
             activation='elu',
             layer_nb_feats=None,
             use_residuals=False,
             nb_conv_per_level=2,
             conv_dropout=0,
             batch_norm=None,
             input_model=None):
    """
    Fully Convolutional Encoder. Copied with permission from github.com/adalca/neurite.

    Builds the down-sampling arm of a CNN: `nb_levels` levels of `nb_conv_per_level`
    convolutions each, with max-pooling after every level except the last. The number
    of features at level L is round(nb_features * feat_mult ** L), unless an explicit
    per-layer list is given via `layer_nb_feats`. Optional residual connections,
    feature-wise dropout and batch normalization per level.

    :param input_shape: spatial shape + channel dim; its length determines the conv
        dimensionality (2D/3D). Only used to create an Input layer if input_model is None.
    :param input_model: (optional) keras model whose output becomes this encoder's input.
    :return: a keras.Model mapping the input tensor(s) to the deepest feature map.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # first layer: input (either a fresh Input layer, or the output of input_model)
    name = '%s_input' % prefix
    if input_model is None:
        input_tensor = keras.layers.Input(shape=input_shape, name=name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.inputs
        last_tensor = input_model.outputs
        if isinstance(last_tensor, list):
            last_tensor = last_tensor[0]

    # volume size data
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        pool_size = (pool_size,) * ndims

    # prepare layers (ConvND / MaxPoolingND matching the input dimensionality)
    convL = getattr(keras.layers, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation, 'data_format': 'channels_last'}
    maxpool = getattr(keras.layers, 'MaxPooling%dD' % ndims)
    # disable fused batch-norm kernels on Apple Silicon (presumably a TF workaround there)
    fused_batch_norm = False if (platform.system() == 'Darwin' and platform.machine() == 'arm64') else None

    # down arm:
    # add nb_levels of conv + ReLu + conv + ReLu. Pool after each of first nb_levels - 1 layers
    lfidx = 0  # level feature index
    for level in range(nb_levels):
        lvl_first_tensor = last_tensor
        nb_lvl_feats = np.round(nb_features * feat_mult ** level).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** level

        for conv in range(nb_conv_per_level):  # does several conv per level, max pooling applied at the end
            if layer_nb_feats is not None:  # None or List of all the feature numbers
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_downarm_%d_%d' % (prefix, level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:  # no activation: in residual mode the activation is applied after the add below
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # conv dropout along feature space only (same mask over all spatial locations)
                name = '%s_dropout_downarm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = keras.layers.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        if use_residuals:
            convarm_layer = last_tensor

            # the "add" layer is the original input
            # However, it may not have the right number of features to be added
            nb_feats_in = lvl_first_tensor.get_shape()[-1]
            nb_feats_out = convarm_layer.get_shape()[-1]
            add_layer = lvl_first_tensor
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                # project the level input to the right feature count before adding
                name = '%s_expand_down_merge_%d' % (prefix, level)
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(lvl_first_tensor)
                add_layer = last_tensor

                if conv_dropout > 0:
                    # NOTE(review): noise_shape is computed but no Dropout layer is added here;
                    # this matches the upstream neurite code — confirm whether intentional.
                    name = '%s_dropout_down_merge_%d_%d' % (prefix, level, conv)
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]

            name = '%s_res_down_merge_%d' % (prefix, level)
            last_tensor = keras.layers.add([add_layer, convarm_layer], name=name)

            # activation is applied after the residual add (the last conv had none)
            name = '%s_res_down_merge_act_%d' % (prefix, level)
            last_tensor = keras.layers.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_down_%d' % (prefix, level)
            last_tensor = keras.layers.BatchNormalization(axis=batch_norm, name=name, fused=fused_batch_norm)(last_tensor)

        # max pool if we're not at the last level
        if level < (nb_levels - 1):
            name = '%s_maxpool_%d' % (prefix, level)
            last_tensor = maxpool(pool_size=pool_size, name=name, padding=padding)(last_tensor)

    # create the model and return
    model = keras.Model(inputs=input_tensor, outputs=[last_tensor], name=model_name)
    return model


def conv_dec(nb_features,
             input_shape,
             nb_levels,
             conv_size,
             nb_labels,
             name=None,
             prefix=None,
             feat_mult=1,
             pool_size=2,
             use_skip_connections=False,
             skip_n_concatenations=0,
             padding='same',
             dilation_rate_mult=1,
             activation='elu',
             use_residuals=False,
             final_pred_activation='softmax',
             nb_conv_per_level=2,
             layer_nb_feats=None,
             batch_norm=None,
             conv_dropout=0,
             input_model=None):
    """
    Fully Convolutional Decoder. Copied with permission from github.com/adalca/neurite.

    Builds the up-sampling arm: nb_levels - 1 levels of upsample (+ optional skip
    concatenation with the matching encoder layer) + nb_conv_per_level convolutions,
    followed by a 1x1(x1) conv producing `nb_labels` channels and a final softmax
    (or linear) prediction layer.

    Parameters:
        ...
        use_skip_connections (bool): if true, turns an Enc-Dec to a U-Net.
            If true, input_tensor and tensors are required.
            It assumes the encoder layer naming produced by conv_enc
            ('<prefix>_conv_downarm_<level>_<conv>').
        skip_n_concatenations (int): number of deepest levels for which the skip
            concatenation is skipped.
    """

    # naming
    model_name = name
    if prefix is None:
        prefix = model_name

    # if using skip connections, make sure need to use them.
    if use_skip_connections:
        assert input_model is not None, "is using skip connections, tensors dictionary is required"

    # first layer: input (fresh Input layer, or continue from the encoder model)
    input_name = '%s_input' % prefix
    if input_model is None:
        input_tensor = keras.layers.Input(shape=input_shape, name=input_name)
        last_tensor = input_tensor
    else:
        input_tensor = input_model.input
        last_tensor = input_model.output
        input_shape = last_tensor.shape.as_list()[1:]

    # vol size info
    ndims = len(input_shape) - 1
    if isinstance(pool_size, int):
        if ndims > 1:
            pool_size = (pool_size,) * ndims

    # prepare layers (ConvND / UpSamplingND matching the input dimensionality)
    convL = getattr(keras.layers, 'Conv%dD' % ndims)
    conv_kwargs = {'padding': padding, 'activation': activation}
    upsample = getattr(keras.layers, 'UpSampling%dD' % ndims)
    # disable fused batch-norm kernels on Apple Silicon (presumably a TF workaround there)
    fused_batch_norm = False if (platform.system() == 'Darwin' and platform.machine() == 'arm64') else None

    # up arm:
    # nb_levels - 1 layers of Deconvolution3D
    #    (approx via up + conv + ReLu) + merge + conv + ReLu + conv + ReLu
    lfidx = 0
    for level in range(nb_levels - 1):
        # feature count mirrors the encoder level we are upsampling back to
        nb_lvl_feats = np.round(nb_features * feat_mult ** (nb_levels - 2 - level)).astype(int)
        conv_kwargs['dilation_rate'] = dilation_rate_mult ** (nb_levels - 2 - level)

        # upsample matching the max pooling layers size
        name = '%s_up_%d' % (prefix, nb_levels + level)
        last_tensor = upsample(size=pool_size, name=name)(last_tensor)
        up_tensor = last_tensor

        # merge layers combining previous layer
        # NOTE: '&' is a bitwise and, but both operands are bool here so it behaves like 'and'
        if use_skip_connections & (level < (nb_levels - skip_n_concatenations - 1)):
            # fetch the last conv output of the matching encoder level by its layer name
            conv_name = '%s_conv_downarm_%d_%d' % (prefix, nb_levels - 2 - level, nb_conv_per_level - 1)
            cat_tensor = input_model.get_layer(conv_name).output
            name = '%s_merge_%d' % (prefix, nb_levels + level)
            last_tensor = keras.layers.concatenate([cat_tensor, last_tensor], axis=ndims + 1, name=name)

        # convolution layers
        for conv in range(nb_conv_per_level):
            if layer_nb_feats is not None:
                nb_lvl_feats = layer_nb_feats[lfidx]
                lfidx += 1

            name = '%s_conv_uparm_%d_%d' % (prefix, nb_levels + level, conv)
            if conv < (nb_conv_per_level - 1) or (not use_residuals):
                last_tensor = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(last_tensor)
            else:  # no activation: applied after the residual add below
                last_tensor = convL(nb_lvl_feats, conv_size, padding=padding, name=name)(last_tensor)

            if conv_dropout > 0:
                # dropout along feature space only; name indexes by 'level' (unlike the
                # conv layers above, which use nb_levels + level) — cosmetic inconsistency
                name = '%s_dropout_uparm_%d_%d' % (prefix, level, conv)
                noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                last_tensor = keras.layers.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

        # residual block
        if use_residuals:

            # the "add" layer is the original input
            # However, it may not have the right number of features to be added
            add_layer = up_tensor
            nb_feats_in = add_layer.get_shape()[-1]
            nb_feats_out = last_tensor.get_shape()[-1]
            if nb_feats_in > 1 and nb_feats_out > 1 and (nb_feats_in != nb_feats_out):
                name = '%s_expand_up_merge_%d' % (prefix, level)
                add_layer = convL(nb_lvl_feats, conv_size, **conv_kwargs, name=name)(add_layer)

                if conv_dropout > 0:
                    # NOTE(review): dropout is applied to last_tensor here, not to the freshly
                    # expanded add_layer; this matches upstream neurite — confirm intentional.
                    name = '%s_dropout_up_merge_%d_%d' % (prefix, level, conv)
                    noise_shape = [None, *[1] * ndims, nb_lvl_feats]
                    last_tensor = keras.layers.Dropout(conv_dropout, noise_shape=noise_shape, name=name)(last_tensor)

            name = '%s_res_up_merge_%d' % (prefix, level)
            last_tensor = keras.layers.add([last_tensor, add_layer], name=name)

            name = '%s_res_up_merge_act_%d' % (prefix, level)
            last_tensor = keras.layers.Activation(activation, name=name)(last_tensor)

        if batch_norm is not None:
            name = '%s_bn_up_%d' % (prefix, level)
            last_tensor = keras.layers.BatchNormalization(axis=batch_norm, name=name, fused=fused_batch_norm)(last_tensor)

    # Compute likelihood prediction (no activation yet): 1-kernel conv to nb_labels channels
    name = '%s_likelihood' % prefix
    last_tensor = convL(nb_labels, 1, activation=None, name=name)(last_tensor)
    like_tensor = last_tensor

    # output prediction layer
    # we use a softmax to compute P(L_x|I) where x is each location
    if final_pred_activation == 'softmax':
        name = '%s_prediction' % prefix
        softmax_lambda_fcn = lambda x: keras.activations.softmax(x, axis=ndims + 1)
        pred_tensor = keras.layers.Lambda(softmax_lambda_fcn, name=name)(last_tensor)

    # otherwise create a layer that does nothing.
    else:
        name = '%s_prediction' % prefix
        pred_tensor = keras.layers.Activation('linear', name=name)(like_tensor)

    # create the model and return
    model = keras.Model(inputs=input_tensor, outputs=pred_tensor, name=model_name)
    return model


# ================================================================================================
#                                        Lab2Im Utilities
# ================================================================================================


# ---------------------------------------------- loading/saving functions ----------------------------------------------


def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None):
    """Load a volume from disk.

    :param path_volume: path to a .nii, .nii.gz, .mgz, or .npz file. For .npz files the array
    is expected under the key 'vol_data' and is paired with an identity affine and blank header.
    :param im_only: (optional) if False, also return the affine matrix and header.
    :param squeeze: (optional) whether to squeeze singleton dimensions after loading.
    :param dtype: (optional) numpy dtype to cast the volume to (values are rounded first for
    integer targets).
    :param aff_ref: (optional) 4x4 affine matrix; if given, the volume is re-aligned to it and
    the returned affine is expressed in that space.
    :return: the volume, plus (aff, header) when im_only is False.
    """
    assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume

    # read the raw data together with its geometry
    if path_volume.endswith('.npz'):
        volume = np.load(path_volume)['vol_data']
        aff, header = np.eye(4), nib.Nifti1Header()
    else:
        nifty = nib.load(path_volume)
        volume = nifty.get_fdata()
        aff, header = nifty.affine, nifty.header
    if squeeze:
        volume = np.squeeze(volume)

    # optional cast (round first when targeting an integer dtype)
    if dtype is not None:
        if 'int' in dtype:
            volume = np.round(volume)
        volume = volume.astype(dtype=dtype)

    # optionally re-align the volume (and its affine) to the reference affine matrix
    if aff_ref is not None:
        n_dims, _ = get_dims(list(volume.shape), max_channels=10)
        volume, aff = align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims)

    return volume if im_only else (volume, aff, header)


def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
    """Save a volume to disk (.npz or any nibabel-supported format).

    :param volume: numpy array to save.
    :param aff: affine matrix. None means identity; the string 'FS' selects the standard
    FreeSurfer output affine.
    :param header: nibabel header; None means a blank Nifti1 header.
    :param path: destination file path (parent folders are created if needed).
    :param res: (optional) resolution to write into the header zooms before saving.
    :param dtype: (optional) numpy dtype to cast the volume to (rounded first for ints).
    :param n_dims: (optional) number of spatial dimensions, used when reformatting res;
    inferred from the volume if None.
    """
    mkdir(os.path.dirname(path))

    # npz files only store the raw array
    if '.npz' in path:
        np.savez_compressed(path, vol_data=volume)
        return

    if header is None:
        header = nib.Nifti1Header()

    # resolve the affine: 'FS' shortcut, explicit matrix, or identity
    if isinstance(aff, str):
        if aff == 'FS':
            aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
    elif aff is None:
        aff = np.eye(4)

    # build the nibabel image, casting and recording the dtype when requested
    if dtype is None:
        nifty = nib.Nifti1Image(volume, aff, header)
    else:
        if 'int' in dtype:
            volume = np.round(volume)
        volume = volume.astype(dtype=dtype)
        nifty = nib.Nifti1Image(volume, aff, header)
        nifty.set_data_dtype(dtype)

    # optionally update the voxel resolution stored in the header
    if res is not None:
        if n_dims is None:
            n_dims, _ = get_dims(volume.shape)
        nifty.header.set_zooms(reformat_to_list(res, length=n_dims, dtype=None))

    nib.save(nifty, path)


def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10):
    """Gather information about a volume: shape, affine, dimensions/channels, header, resolution.

    :param path_volume: path of the volume to inspect.
    :param return_volume: (optional) whether to also return the volume itself.
    :param aff_ref: (optional) 4x4 affine; if given, the volume and its shape/resolution info
    are re-expressed in that space. The returned affine stays the original (pre-alignment) one.
    :param max_channels: maximum possible number of channels for the input volume.
    :return: (volume,) if requested, followed by im_shape, aff, n_dims, n_channels, header,
    data_res.
    """
    # read image
    im, aff, header = load_volume(path_volume, im_only=False)

    # separate spatial dimensions from a possible channel dimension
    im_shape = list(im.shape)
    n_dims, n_channels = get_dims(im_shape, max_channels=max_channels)
    im_shape = im_shape[:n_dims]

    # voxel resolution, read from the format-specific header field
    if '.nii' in path_volume:
        data_res = np.array(header['pixdim'][1:n_dims + 1])
    elif '.mgz' in path_volume:
        data_res = np.array(header['delta'])  # mgz image
    else:
        data_res = np.array([1.0] * n_dims)

    # align to the reference affine and permute shape/resolution entries accordingly
    if aff_ref is not None:
        ras_axes = get_ras_axes(aff, n_dims=n_dims)
        ras_axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
        im = align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims)
        im_shape, data_res = np.array(im_shape), np.array(data_res)
        im_shape[ras_axes_ref] = im_shape[ras_axes]
        data_res[ras_axes_ref] = data_res[ras_axes]
        im_shape = im_shape.tolist()

    info = (im_shape, aff, n_dims, n_channels, header, data_res)
    return (im,) + info if return_volume else info


def load_array_if_path(var, load_as_numpy=True):
    """If var is a string and load_as_numpy is True, load and return the numpy array stored
    at the path given by var. Otherwise return var unchanged.

    :param var: any object; only str values are candidates for loading.
    :param load_as_numpy: (optional) whether string inputs should be treated as numpy paths.
    :return: the loaded numpy array, or var as-is.
    """
    # use short-circuiting 'and' instead of the previous bitwise '&' on booleans
    if isinstance(var, str) and load_as_numpy:
        assert os.path.isfile(var), 'No such path: %s' % var
        var = np.load(var)
    return var


# ----------------------------------------------- reformatting functions -----------------------------------------------


def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None):
    """Reformat a variable into a list of desired length and item type.

    If var is a string and load_as_numpy is True, it is first loaded as a numpy array.
    If var is None, this function returns None.
    :param var: a str, int, float, bool, list, tuple, or numpy array.
    :param length: (optional) if var yields a single item, it is replicated to this length.
    :param load_as_numpy: (optional) whether var is the path to a numpy array.
    :param dtype: (optional) convert all items to this type: 'int', 'float', 'bool', or 'str'.
    :return: the reformatted list (or None).
    """
    if var is None:
        return None

    # optionally load from disk (inlined from load_array_if_path, identical behavior)
    if load_as_numpy and isinstance(var, str):
        assert os.path.isfile(var), 'No such path: %s' % var
        var = np.load(var)

    # convert to a list; bool and numpy scalars are covered by the scalar branch
    # (the previous separate bool branch was dead code, since bool subclasses int)
    if isinstance(var, (int, float, bool, str, np.integer, np.floating)):
        var = [var]
    elif isinstance(var, tuple):
        var = list(var)
    elif isinstance(var, np.ndarray):
        # squeeze first; a 0-d result (e.g. input of shape (1, 1)) previously produced a bare
        # scalar from .tolist() and crashed below — now it correctly yields a 1-element list
        var = np.squeeze(var)
        var = [var.item()] if var.ndim == 0 else var.tolist()
    if not isinstance(var, list):
        raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array')

    # replicate a singleton to the requested length, or validate the length
    if length is not None:
        if len(var) == 1:
            var = var * length
        elif len(var) != length:
            raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, '
                             'had {1}'.format(length, var))

    # cast items to the requested type
    if dtype is not None:
        casters = {'int': int, 'float': float, 'bool': bool, 'str': str}
        if dtype not in casters:
            raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype))
        var = [casters[dtype](v) for v in var]
    return var


# ----------------------------------------------- path-related functions -----------------------------------------------


def list_images_in_folder(path_dir, include_single_image=True, check_if_empty=True):
    """List all files with extension nii, nii.gz, mgz, or npz within a folder.

    :param path_dir: path of a folder, or (if include_single_image) path of a single image file.
    :param include_single_image: (optional) if path_dir points at one image file, return it as a
    one-element list instead of treating it as a folder.
    :param check_if_empty: (optional) raise if the folder contains no image at all.
    :return: sorted list of image paths.
    """
    # use short-circuiting 'and'/'any' instead of the previous bitwise '&'/'|' chain
    basename = os.path.basename(path_dir)
    if include_single_image and any(ext in basename for ext in ('.nii.gz', '.nii', '.mgz', '.npz')):
        assert os.path.isfile(path_dir), 'file %s does not exist' % path_dir
        list_images = [path_dir]
    else:
        if not os.path.isdir(path_dir):
            raise Exception('Folder does not exist: %s' % path_dir)
        # keep the original glob patterns (note: '*nii.gz'/'*nii' deliberately lack a dot)
        list_images = sorted(glob.glob(os.path.join(path_dir, '*nii.gz')) +
                             glob.glob(os.path.join(path_dir, '*nii')) +
                             glob.glob(os.path.join(path_dir, '*.mgz')) +
                             glob.glob(os.path.join(path_dir, '*.npz')))
        if check_if_empty:
            assert len(list_images) > 0, 'no .nii, .nii.gz, .mgz or .npz image could be found in %s' % path_dir
    return list_images


def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'):
    """Return the sorted list of subfolders of a folder, optionally filtered by substrings.

    :param path_dir: path of a folder.
    :param whole_path: (optional) whether to return full paths or just subfolder names.
    :param expr: (optional) substring(s) the subfolder names must contain. Str or list of str.
    :param cond_type: (optional) how several expressions combine: 'or' keeps folders matching
    any of them, 'and' keeps folders matching all of them.
    :return: a sorted list of subfolders.
    """
    assert isinstance(whole_path, bool), "whole_path should be bool"
    assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'"

    # collect the sub-directory entries, with or without the parent path
    names = [f for f in os.listdir(path_dir) if os.path.isdir(os.path.join(path_dir, f))]
    if whole_path:
        subdirs_list = sorted(os.path.join(path_dir, f) for f in names)
    else:
        subdirs_list = sorted(names)

    # apply the substring filter(s)
    if expr is not None:  # assumed to be either str or list of str
        if isinstance(expr, str):
            expr = [expr]
        elif not isinstance(expr, (list, tuple)):
            raise Exception("if specified, 'expr' should be a string or list of strings.")
        matched = list()
        for pattern in expr:
            hits = sorted(f for f in subdirs_list if pattern in os.path.basename(f))
            if cond_type == 'or':
                # remove hits from the pool so each folder is counted once, then accumulate
                subdirs_list = [f for f in subdirs_list if f not in hits]
                matched += hits
            else:  # 'and': successively narrow the pool down to folders matching every pattern
                subdirs_list = hits
                matched = hits
        subdirs_list = sorted(matched)
    return subdirs_list


def mkdir(path_dir):
    """Recursively create path_dir (and any missing parent folders) if it does not exist.

    Replaces the previous hand-rolled parent-walking loop with os.makedirs(exist_ok=True),
    which is equivalent, race-safe, and cannot loop forever on pathological inputs.
    :param path_dir: directory path to create; the empty string is a no-op.
    """
    if len(path_dir) > 0:
        # tolerate a single trailing slash, as the original implementation did
        if path_dir[-1] == '/':
            path_dir = path_dir[:-1]
        if len(path_dir) > 0:
            os.makedirs(path_dir, exist_ok=True)


# ---------------------------------------------- shape-related functions -----------------------------------------------


def get_dims(shape, max_channels=10):
    """Infer the number of dimensions and channels from an array shape.

    The last entry of the shape is interpreted as a channel dimension whenever it is
    inferior or equal to max_channels; otherwise the array is considered single-channel.
    :param shape: shape of an array. Can be a sequence or a 1d numpy array.
    :param max_channels: maximum possible number of channels.
    :return: (n_dims, n_channels) for the provided shape.
    example 1: get_dims([150, 150, 150], max_channels=10) = (3, 1)
    example 2: get_dims([150, 150, 150, 3], max_channels=10) = (3, 3)
    example 3: get_dims([150, 150, 150, 15], max_channels=10) = (4, 1)
    """
    last = shape[-1]
    return (len(shape) - 1, last) if last <= max_channels else (len(shape), 1)


def add_axis(x, axis=0):
    """Insert one or several new axes into a numpy array.

    :param x: input array.
    :param axis: index of the new axis to add, or a list of indices to insert several axes
    at the same time (applied in order).
    :return: the expanded array.
    """
    for ax in reformat_to_list(axis):
        x = np.expand_dims(x, axis=ax)
    return x


# --------------------------------------------------- miscellaneous ----------------------------------------------------


class LoopInfo:
    """
    Class to print the current iteration in a for loop, and optionally the estimated remaining time.
    Instantiate just before the loop, and call the update method at the start of the loop.
    The printed text has the following format:
    processing i/total    remaining time: hh:mm:ss
    """

    def __init__(self, n_iterations, spacing=10, text='processing', print_time=False):
        """
        :param n_iterations: total number of iterations of the for loop.
        :param spacing: frequency at which the update info will be printed on screen.
        :param text: text to print. Default is processing.
        :param print_time: whether to print the estimated remaining time. Default is False.
        """

        # loop parameters
        self.n_iterations = n_iterations
        self.spacing = spacing

        # text parameters
        self.text = text
        self.print_time = print_time
        # once a remaining time has been printed, keep printing it even if it drops below 1s
        self.print_previous_time = False
        # column width for the 'i/total' field: both numbers, the slash, plus padding
        self.align = len(str(self.n_iterations)) * 2 + 1 + 3

        # timing parameters
        # per-iteration durations, used to estimate the remaining time
        self.iteration_durations = np.zeros((n_iterations,))
        self.start = time.time()
        self.previous = time.time()

    def update(self, idx):
        """Record the duration of iteration idx and print progress every `spacing` iterations.

        :param idx: 0-based index of the current iteration.
        """

        # time iteration
        now = time.time()
        self.iteration_durations[idx] = now - self.previous
        self.previous = now

        # print text
        if idx == 0:
            print(self.text + ' 1/{}'.format(self.n_iterations))
        elif idx % self.spacing == self.spacing - 1:
            iteration = str(idx + 1) + '/' + str(self.n_iterations)
            if self.print_time:
                # estimate remaining time from the average duration, ignoring near-zero entries
                # (unvisited iterations and outlier-fast ones below 1% of the max duration)
                max_duration = np.max(self.iteration_durations)
                average_duration = np.mean(self.iteration_durations[self.iteration_durations > .01 * max_duration])
                remaining_time = int(average_duration * (self.n_iterations - idx))
                # print total remaining time only if it is greater than 1s or if it was previously printed
                if (remaining_time > 1) | self.print_previous_time:
                    eta = str(timedelta(seconds=remaining_time))
                    print(self.text + ' {:<{x}} remaining time: {}'.format(iteration, eta, x=self.align))
                    self.print_previous_time = True
                else:
                    print(self.text + ' {}'.format(iteration))
            else:
                print(self.text + ' {}'.format(iteration))


def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
    """Return the closest integer to n that is divisible by m.

    :param n: target integer.
    :param m: divisor.
    :param answer_type: 'closer' (nearest multiple, ties go up), 'lower' (largest multiple
    <= n), or 'higher' (smallest multiple >= n).
    :return: the selected multiple of m.
    """
    if n % m == 0:
        return n
    # floor division: the previous int(n / m) truncated toward zero, which made 'lower'
    # return a value GREATER than n for negative n (e.g. n=-5, m=3 gave -3 instead of -6)
    lower = (n // m) * m
    higher = lower + m
    if answer_type == 'lower':
        return lower
    elif answer_type == 'higher':
        return higher
    elif answer_type == 'closer':
        return lower if (n - lower) < (higher - n) else higher
    else:
        raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type)


def build_binary_structure(connectivity, n_dims, shape=None):
    """Return a binary dilation/erosion structuring element with the provided connectivity.

    :param connectivity: maximum euclidean distance from the centre voxel that is kept inside
    the element.
    :param n_dims: number of dimensions of the element.
    :param shape: (optional) explicit shape of the element; defaults to (2*connectivity+1)^n_dims.
    :return: an integer (0/1) numpy array marking the element.
    """
    if shape is None:
        shape = [2 * connectivity + 1] * n_dims
    else:
        shape = reformat_to_list(shape, length=n_dims)
    # distance of every voxel to the centre, thresholded at the requested connectivity
    dist = np.ones(shape)
    dist[tuple((s // 2,) for s in shape)] = 0
    return (distance_transform_edt(dist) <= connectivity) * 1


# ---------------------------------------------------- edit volume -----------------------------------------------------

def rescale_volume(volume, new_min=0, new_max=255, min_percentile=2., max_percentile=98., use_positive_only=False):
    """Linearly rescale a volume to [new_min, new_max] using robust percentile extrema.

    :param volume: a numpy array.
    :param new_min: (optional) minimum value of the rescaled image.
    :param new_max: (optional) maximum value of the rescaled image.
    :param min_percentile: (optional) percentile used as the robust minimum (0 means np.min).
    :param max_percentile: (optional) percentile used as the robust maximum (100 means np.max).
    :param use_positive_only: (optional) estimate the percentiles on positive voxels only.
    :return: the rescaled volume (all zeros when the robust range is empty).
    """
    new_volume = volume.copy()

    # intensities from which the robust range is estimated
    intensities = new_volume[new_volume > 0] if use_positive_only else new_volume.flatten()

    # robust extrema; the 0/100 endpoints use exact min/max rather than np.percentile
    robust_min = np.min(intensities) if min_percentile == 0 else np.percentile(intensities, min_percentile)
    robust_max = np.max(intensities) if max_percentile == 100 else np.percentile(intensities, max_percentile)

    # clip outliers, then map linearly; a flat range maps to zeros to avoid division by zero
    new_volume = np.clip(new_volume, robust_min, robust_max)
    if robust_min == robust_max:
        return np.zeros_like(new_volume)
    return new_min + (new_volume - robust_min) / (robust_max - robust_min) * (new_max - new_min)


def crop_volume(volume, cropping_margin=None, cropping_shape=None, aff=None, return_crop_idx=False, mode='center'):
    """Crop a volume, either by removing a margin from its borders or by extracting a sub-volume of a given shape.
    :param volume: 2d or 3d numpy array (possibly with multiple channels)
    :param cropping_margin: (optional) margin removed from both sides of every spatial dimension. Can be an int, a
    sequence, or a 1d numpy array of size n_dims. Exactly one of cropping_margin/cropping_shape must be given.
    :param cropping_shape: (optional) spatial shape of the extracted sub-volume. Can be an int, a sequence, or a
    1d numpy array of size n_dims.
    :param aff: (optional) affine matrix of the input volume. If given, an updated affine for the cropped volume is
    also returned (the matrix is modified in place).
    :param return_crop_idx: (optional) whether to also return the cropping indices, in the order
    [lower_dim_1, ..., lower_dim_n, upper_dim_1, ..., upper_dim_n].
    :param mode: (optional) with cropping_shape, 'center' extracts the centre of the volume while 'random' picks a
    random sub-volume. Default is 'center'.
    :return: the cropped volume, followed by the updated affine (if aff is not None) and the cropping indices
    (if return_crop_idx is True).
    """

    assert (cropping_margin is not None) | (cropping_shape is not None), \
        'cropping_margin or cropping_shape should be provided'
    assert not ((cropping_margin is not None) & (cropping_shape is not None)), \
        'only one of cropping_margin or cropping_shape should be provided'

    # work on a copy of the input
    cropped = volume.copy()
    full_shape = cropped.shape
    n_dims, _ = get_dims(full_shape)
    spatial_shape = np.array(full_shape)[:n_dims]

    # compute lower/upper cropping bounds for every spatial dimension
    if cropping_margin is not None:
        margins = reformat_to_list(cropping_margin, length=n_dims)
        lower, upper = [], []
        for d in range(n_dims):
            # only crop dimensions large enough to accommodate the margin on both sides
            if full_shape[d] > 2 * margins[d]:
                lower.append(margins[d])
                upper.append(full_shape[d] - margins[d])
            else:
                lower.append(0)
                upper.append(full_shape[d])
    else:
        target = reformat_to_list(cropping_shape, length=n_dims)
        if mode == 'center':
            lower = np.maximum([int((full_shape[d] - target[d]) / 2) for d in range(n_dims)], 0)
            upper = np.minimum(lower + np.array(target), spatial_shape)
        elif mode == 'random':
            leeway = np.maximum(np.array([full_shape[d] - target[d] for d in range(n_dims)]), 0)
            lower = np.random.randint(0, high=leeway + 1)
            upper = np.minimum(lower + np.array(target), spatial_shape)
        else:
            raise ValueError('mode should be either "center" or "random", had %s' % mode)
    crop_idx = np.concatenate([np.array(lower), np.array(upper)])

    # apply the crop (trailing channel dimension, if any, is kept whole)
    if n_dims == 2:
        cropped = cropped[crop_idx[0]: crop_idx[2], crop_idx[1]: crop_idx[3], ...]
    elif n_dims == 3:
        cropped = cropped[crop_idx[0]: crop_idx[3], crop_idx[1]: crop_idx[4], crop_idx[2]: crop_idx[5], ...]

    # assemble outputs in the documented order
    output = [cropped]
    if aff is not None:
        # shift the origin so that world coordinates are preserved after cropping
        aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ np.array(lower)
        output.append(aff)
    if return_crop_idx:
        output.append(crop_idx)
    return output[0] if len(output) == 1 else tuple(output)


def crop_volume_around_region(volume,
                              mask=None,
                              masking_labels=None,
                              threshold=0.1,
                              margin=0,
                              cropping_shape=None,
                              cropping_shape_div_by=None,
                              aff=None,
                              overflow='strict'):
    """Crop a volume around a specific region.
    This region is defined by a mask obtained by either:
    1) directly specifying it as input (see mask)
    2) keeping a set of label values if the volume is a label map (see masking_labels).
    3) thresholding the input volume (see threshold)
    The cropping region is defined by the bounding box of the mask, which we can further modify by either:
    1) extending it by a margin (see margin)
    2) providing a specific cropping shape, in this case the cropping region will be centered around the bounding box
    (see cropping_shape).
    3) extending it to a shape that is divisible by a given number. Again, the cropping region will be centered around
    the bounding box (see cropping_shape_div_by).
    Finally, if the size of the cropping region has been modified, and this modified size overflows out of the
    image (e.g. because the center of the mask is close to the edge), we can either:
    1) 'strict': stick to the valid image space (the size of the modified cropping region won't be respected)
    2) 'shift-strict': shift the cropping region so that it lies on the valid image space, and if it still overflows,
    then we restrict to the valid image space.
    3) 'padding': pad the image with zeros, such that the cropping region is not ill-defined anymore.
    4) 'shift-padding': shift the cropping region to the valid image space, and if it still overflows, then we pad
    with zeros.
    :param volume: a 2d or 3d numpy array
    :param mask: (optional) mask of region to crop around. Must be same size as volume. Can either be boolean or 0/1.
    If no mask is given, it will be computed by either thresholding the input volume or using masking_labels.
    :param masking_labels: (optional) if mask is None, and if the volume is a label map, it can be cropped around a
    set of labels specified in masking_labels, which can either be a single int, a sequence or a 1d numpy array.
    :param threshold: (optional) if mask and masking_labels are None, lower bound to determine values to crop around.
    :param margin: (optional) add margin around mask
    :param cropping_shape: (optional) shape to which the input volumes must be cropped. Volumes are padded around the
    centre of the above-defined mask if they are too small for the given shape. Can be an integer or sequence.
    Cannot be given at the same time as margin or cropping_shape_div_by.
    :param cropping_shape_div_by: (optional) makes sure the shape of the cropped region is divisible by the provided
    number. If it is not, then we enlarge the cropping area. If the enlarged area is too big for the input volume, we
    pad it with 0. Must be an integer. Cannot be given at the same time as margin or cropping_shape.
    :param aff: (optional) if specified, this function returns an updated affine matrix of the volume after cropping
    (the matrix is modified in place).
    :param overflow: (optional) how to proceed when the cropping region overflows outside the initial image space.
    Can either be 'strict' (default), 'shift-strict', 'padding', or 'shift-padding'.
    :return: the cropped volume, the cropping indices (in the order [lower_bound_dim_1, ..., upper_bound_dim_1, ...],
    or None if there was nothing to crop around), and the updated affine matrix if aff is not None.
    """

    assert not ((margin > 0) & (cropping_shape is not None)), "margin and cropping_shape can't be given together."
    assert not ((margin > 0) & (cropping_shape_div_by is not None)), \
        "margin and cropping_shape_div_by can't be given together."
    assert not ((cropping_shape_div_by is not None) & (cropping_shape is not None)), \
        "cropping_shape_div_by and cropping_shape can't be given together."

    new_vol = volume.copy()
    n_dims, n_channels = get_dims(new_vol.shape)
    vol_shape = np.array(new_vol.shape[:n_dims])

    # build the mask to crop around if it was not directly provided
    if mask is None:
        if masking_labels is not None:
            _, mask = mask_label_map(new_vol, masking_values=masking_labels, return_mask=True)
        else:
            mask = new_vol > threshold

    # find cropping indices
    if np.any(mask):

        # bounding box of the mask
        indices = np.nonzero(mask)
        min_idx = np.array([np.min(idx) for idx in indices])
        max_idx = np.array([np.max(idx) for idx in indices])
        intermediate_vol_shape = max_idx - min_idx

        # target shape of the cropping region (defaults to the bounding box itself)
        if (margin == 0) & (cropping_shape is None) & (cropping_shape_div_by is None):
            cropping_shape = intermediate_vol_shape
        if margin:
            cropping_shape = intermediate_vol_shape + 2 * margin
        elif cropping_shape is not None:
            cropping_shape = np.array(reformat_to_list(cropping_shape, length=n_dims))
        elif cropping_shape_div_by is not None:
            cropping_shape = [find_closest_number_divisible_by_m(s, cropping_shape_div_by, answer_type='higher')
                              for s in intermediate_vol_shape]

        # centre the cropping region on the bounding box, and measure by how much it overflows the image
        min_idx = min_idx - np.int32(np.ceil((cropping_shape - intermediate_vol_shape) / 2))
        max_idx = max_idx + np.int32(np.floor((cropping_shape - intermediate_vol_shape) / 2))
        min_overflow = np.abs(np.minimum(min_idx, 0))
        max_overflow = np.maximum(max_idx - vol_shape, 0)

        # 'strict'/'shift-strict': no padding allowed, so reset the overflows (indices are clipped below)
        if 'strict' in overflow:
            min_overflow = np.zeros_like(min_overflow)
            max_overflow = np.zeros_like(min_overflow)

        if overflow == 'shift-strict':
            min_idx -= max_overflow
            max_idx += min_overflow

        if overflow == 'shift-padding':
            for ii in range(n_dims):
                # no need to do anything if both min/max_overflow are 0 (no padding/shifting required at all)
                # or if both are positive, because in this case we don't shift at all and we pad directly
                if (min_overflow[ii] > 0) & (max_overflow[ii] == 0):
                    # shift towards the upper end, keeping a pad only for what still doesn't fit
                    max_idx_new = max_idx[ii] + min_overflow[ii]
                    if max_idx_new <= vol_shape[ii]:
                        max_idx[ii] = max_idx_new
                        min_overflow[ii] = 0
                    else:
                        min_overflow[ii] = min_overflow[ii] - (vol_shape[ii] - max_idx[ii])
                        max_idx[ii] = vol_shape[ii]
                elif (min_overflow[ii] == 0) & (max_overflow[ii] > 0):
                    # shift towards the lower end, keeping a pad only for what still doesn't fit
                    min_idx_new = min_idx[ii] - max_overflow[ii]
                    if min_idx_new >= 0:
                        min_idx[ii] = min_idx_new
                        max_overflow[ii] = 0
                    else:
                        max_overflow[ii] = max_overflow[ii] - min_idx[ii]
                        min_idx[ii] = 0

        # crop volume if necessary (indices clipped to the valid image space)
        min_idx = np.maximum(min_idx, 0)
        max_idx = np.minimum(max_idx, vol_shape)
        cropping = np.concatenate([min_idx, max_idx])
        # use n_dims (previously a hard-coded 3) so this check is also correct for 2d volumes
        if np.any(cropping[:n_dims] > 0) or np.any(cropping[n_dims:] != vol_shape):
            if n_dims == 3:
                new_vol = new_vol[cropping[0]:cropping[3], cropping[1]:cropping[4], cropping[2]:cropping[5], ...]
            elif n_dims == 2:
                new_vol = new_vol[cropping[0]:cropping[2], cropping[1]:cropping[3], ...]
            else:
                raise ValueError('cannot crop volumes with more than 3 dimensions')

        # pad volume if necessary
        if np.any(min_overflow > 0) | np.any(max_overflow > 0):
            pad_margins = tuple([(min_overflow[i], max_overflow[i]) for i in range(n_dims)])
            pad_margins = tuple(list(pad_margins) + [(0, 0)]) if n_channels > 1 else pad_margins
            new_vol = np.pad(new_vol, pad_margins, mode='constant', constant_values=0)

    # if there's nothing to crop around, we return the input as is
    else:
        # length n_dims (previously a hard-coded 3, which broke the affine update below for 2d volumes)
        min_idx = min_overflow = np.zeros(n_dims)
        cropping = None

    # return results
    if aff is not None:
        if n_dims == 2:
            min_idx = np.append(min_idx, 0)
            min_overflow = np.append(min_overflow, 0)
        aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ min_idx
        aff[:-1, -1] = aff[:-1, -1] - aff[:-1, :-1] @ min_overflow
        return new_vol, cropping, aff
    else:
        return new_vol, cropping


def crop_volume_with_idx(volume, crop_idx, aff=None, n_dims=None, return_copy=True):
    """Crop a volume with precomputed cropping indices.
    :param volume: a 2d or 3d numpy array
    :param crop_idx: cropping indices, in the order [lower_bound_dim_1, ..., upper_bound_dim_1, ...].
    Can be a list or a 1d numpy array.
    :param aff: (optional) if given, this function also returns an updated affine matrix of the cropped volume
    (the matrix is modified in place).
    :param n_dims: (optional) number of dimensions (excluding channels) of the volume. If not provided, it is
    inferred from the length of crop_idx.
    :param return_copy: (optional) whether to crop a copy of the input or the input itself. Default is copy.
    :return: the cropped volume, and the updated affine matrix if aff is not None.
    """

    # crop_idx holds a lower and an upper bound per spatial dimension
    if n_dims is None:
        n_dims = int(np.array(crop_idx).shape[0] / 2)

    out = volume.copy() if return_copy else volume

    # slice the spatial dimensions, leaving any trailing channel dimension untouched
    if n_dims == 2:
        out = out[crop_idx[0]:crop_idx[2], crop_idx[1]:crop_idx[3], ...]
    elif n_dims == 3:
        out = out[crop_idx[0]:crop_idx[3], crop_idx[1]:crop_idx[4], crop_idx[2]:crop_idx[5], ...]
    else:
        raise Exception('cannot crop volumes with more than 3 dimensions')

    if aff is None:
        return out

    # shift the origin so that world coordinates are preserved after cropping
    aff[0:3, -1] = aff[0:3, -1] + aff[:3, :3] @ crop_idx[:3]
    return out, aff


def resample_volume(volume, aff, new_vox_size, interpolation='linear', blur=True):
    """Resample a 3d volume to a new voxel size, updating the affine so RAS coordinates are preserved.
    :param volume: a numpy array
    :param aff: affine matrix of the volume
    :param new_vox_size: new voxel size (3 - element numpy vector) in mm
    :param interpolation: (optional) type of interpolation. Can be 'linear' or 'nearest'. Default is 'linear'.
    :param blur: (optional) whether to blur before resampling to avoid aliasing effects.
    Only used if the input volume is downsampled. Default is True.
    :return: new volume and affine matrix
    """

    # current voxel size, read off the norm of the affine's direction columns
    vox_size = np.sqrt(np.sum(aff * aff, axis=0))[:-1]
    new_vox_size = np.array(new_vox_size)
    factor = vox_size / new_vox_size

    # anti-aliasing blur, disabled along axes that are upsampled
    sigmas = 0.25 / factor
    sigmas[factor > 1] = 0
    filtered = gaussian_filter(volume, sigmas) if blur else volume

    # interpolator defined on the original voxel grid
    axes = tuple(np.arange(s) for s in filtered.shape)
    interpolator = RegularGridInterpolator(axes, filtered, method=interpolation)

    # sampling coordinates of the new grid, expressed in the old voxel space,
    # clipped so the samples stay inside the original field of view
    start = - (factor - 1) / (2 * factor)
    step = 1.0 / factor
    stop = start + step * np.ceil(filtered.shape * factor)
    coords = [np.clip(np.arange(start=start[d], stop=stop[d], step=step[d]), 0, filtered.shape[d] - 1)
              for d in range(3)]

    grid = np.meshgrid(*coords, indexing='ij', sparse=True)
    resampled = interpolator(tuple(grid))

    # rescale the affine's direction columns and shift the origin by half the voxel-size change
    new_aff = aff.copy()
    for c in range(3):
        new_aff[:-1, c] = new_aff[:-1, c] / factor[c]
    new_aff[:-1, -1] = new_aff[:-1, -1] - new_aff[:-1, :-1] @ (0.5 * (factor - 1))

    return resampled, new_aff


def get_ras_axes(aff, n_dims=3):
    """Identify which dimension of a volume corresponds to each RAS axis, based on its affine matrix.
    :param aff: affine matrix. Can be a 2d numpy array of size n_dims*n_dims, n_dims+1*n_dims+1, or n_dims*n_dims+1.
    :param n_dims: number of dimensions (excluding channels) of the volume corresponding to the provided affine matrix.
    :return: a 1d numpy array of length n_dims, where element i is the volume axis corresponding to RAS axis i.
    """
    inv_aff = np.linalg.inv(aff)
    ras_axes = np.argmax(np.absolute(inv_aff[0:n_dims, 0:n_dims]), axis=0)

    # argmax can map two RAS axes to the same volume axis for oblique affines;
    # reassign the last duplicate to any missing axis so the result is a permutation
    for axis in range(n_dims):
        if axis not in ras_axes:
            values, counts = np.unique(ras_axes, return_counts=True)
            duplicate = values[np.argmax(counts)]
            ras_axes[np.where(ras_axes == duplicate)[0][-1]] = axis

    return ras_axes


def align_volume_to_ref(volume, aff, aff_ref=None, return_aff=False, n_dims=None, return_copy=True):
    """Reorient a volume (axis order and direction) to match a reference affine matrix.
    :param volume: a numpy array
    :param aff: affine matrix of the floating volume
    :param aff_ref: (optional) affine matrix of the target orientation. Default is identity matrix.
    :param return_aff: (optional) whether to also return the affine matrix of the aligned volume
    :param n_dims: (optional) number of dimensions (excluding channels) of the volume. If not provided, n_dims will be
    inferred from the input volume.
    :param return_copy: (optional) whether to work on a copy of the input or on the input itself. Default is copy.
    :return: aligned volume, with corresponding affine matrix if return_aff is True.
    """

    # work on copies so the caller's arrays are untouched (volume only when return_copy)
    aligned = volume.copy() if return_copy else volume
    aff_src = aff.copy()

    if aff_ref is None:
        aff_ref = np.eye(4)

    # map each RAS axis to a volume axis for both orientations
    if n_dims is None:
        n_dims, _ = get_dims(aligned.shape)
    axes_ref = get_ras_axes(aff_ref, n_dims=n_dims)
    axes_src = get_ras_axes(aff_src, n_dims=n_dims)

    # permute the image axes so their ordering matches the reference
    aff_src[:, axes_ref] = aff_src[:, axes_src]
    for i in range(n_dims):
        if axes_src[i] != axes_ref[i]:
            aligned = np.swapaxes(aligned, axes_src[i], axes_ref[i])
            partner = np.where(axes_src == axes_ref[i])
            axes_src[partner], axes_src[i] = axes_src[i], axes_src[partner]

    # flip every axis whose direction opposes the reference, updating the origin accordingly
    dots = np.sum(aff_src[:3, :3] * aff_ref[:3, :3], axis=0)
    for i in range(n_dims):
        if dots[i] < 0:
            aligned = np.flip(aligned, axis=i)
            aff_src[:, i] = - aff_src[:, i]
            aff_src[:3, 3] = aff_src[:3, 3] - aff_src[:3, i] * (aligned.shape[i] - 1)

    return (aligned, aff_src) if return_aff else aligned


def mask_label_map(labels, masking_values, masking_value=0, return_mask=False):
    """Mask out every voxel of a label map whose value is not in a given set.
    :param labels: input label map
    :param masking_values: label value(s) to keep; all other voxels are overwritten
    :param masking_value: (optional) value assigned to the masked-out voxels
    :param return_mask: (optional) whether to also return the mask of kept voxels (as a 0/1 map)
    :return: the masked label map, and the applied mask if return_mask is True.
    """

    # union of the voxels carrying any of the requested values
    keep = np.zeros(labels.shape, dtype=bool)
    for value in reformat_to_list(masking_values):
        keep |= (labels == value)

    # overwrite everything outside the kept region
    masked = labels.copy()
    masked[~keep] = masking_value

    return (masked, keep * 1) if return_mask else masked


def get_largest_connected_component(mask, structure=None):
    """Keep only the largest connected component of a binary mask.
    :param mask: a 2d or 3d label map of boolean type.
    :param structure: numpy array defining the connectivity.
    :return: boolean mask of the largest component, or a copy of the input if it contains no component.
    """
    labelled, n_components = scipy_label(mask, structure)
    if n_components == 0:
        return mask.copy()
    # component ids start at 1, so drop the background count before taking the argmax
    sizes = np.bincount(labelled.flat)[1:]
    return labelled == (np.argmax(sizes) + 1)


# run the command-line entry point only when executed as a script (not when imported)
if __name__ == '__main__':
    main()
