#!/usr/bin/python3

import os
import numpy as np
import argparse
import surfa as sf


# parse commandline args
# NOTE: this runs before the (slow) tensorflow import further down, so usage
# errors are reported without paying the import cost.
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--invol', required=True, help='input MRI volume')
parser.add_argument('-o', '--outvol', required=True, help='output MRI volume')
# restrict hemi at parse time so an invalid value fails fast with a usage message
parser.add_argument('--hemi', required=True, choices=('lh', 'rh'), help='hemi to process')
parser.add_argument('--pred', help='write prediction volume')
parser.add_argument('--norm', help='write normalized volume')
parser.add_argument('--fv', help='bring up freeview to show results', action='store_true')
parser.add_argument('--uthresh', default=5000, type=float, help='specify threshold to erase above')
parser.add_argument('--border', default=4, type=int, help='number of border voxels to set threshold at')
# NOTE(review): --multichannel is accepted but not referenced anywhere else in the visible script
parser.add_argument('--multichannel', action='store_true', help='specify that data has multiple channels')
parser.add_argument('--model', help='use alternative model file')
parser.add_argument('--wts', help='wt filename')
parser.add_argument('--gpu', help='GPU number - if not supplied, CPU is used')
args = parser.parse_args()

# delay slow TF import after parsing cmd line
import tensorflow as tf
import neurite as ne

# only the left ('lh') and right ('rh') hemisphere codes are accepted
if args.hemi != 'lh' and args.hemi != 'rh':
    sf.system.fatal(f'Hemi specification must be either `lh` or `rh`. User provided `{args.hemi}`.')

# read input volume and normalize input intensities
mri_in = sf.load_volume(args.invol)
if args.uthresh:
    # zero out everything brighter than the upper threshold
    # (a threshold of 0 is falsy and therefore disables this step)
    mri_in[mri_in > args.uthresh] = 0

# shift so the minimum is 0, then scale by the 99th percentile and clip to [0, 1]
mri_in = mri_in - mri_in.min()
mri_in = (mri_in / mri_in.percentile(99)).clip(0, 1)

# conform to a 256^3, 1mm, LIA-oriented grid; this is the volume fed to the model
vol_shape = [256] * 3
mri_conf = mri_in.conform(shape=vol_shape, voxsize=1, orientation='LIA')

if args.norm:
    # write the normalized (not conformed) volume to the filename the user supplied
    mri_in.save(args.norm)

# choose the compute device: default to the CPU, switch to the requested GPU if given
device = '/cpu:0'
if args.gpu:
    device, ngpus = ne.tf.utils.setup_device(args.gpu)

# choose the model file: the hemi-specific file under $FREESURFER_HOME/models
# unless an alternative was passed with --model
if args.model is None:
    fshome = os.getenv('FREESURFER_HOME')
    if fshome is None:
        sf.system.fatal('FREESURFER_HOME env variable must be set! Make sure FreeSurfer is properly sourced.')
    modelfile = os.path.join(fshome, 'models', f'exvivo.strip.{args.hemi}.h5')
else:
    modelfile = args.model
    print('Using custom model weights')

# load the model (and optional replacement weights) and run it on the conformed volume
print(f'using device {device}')
with tf.device(device):
    print(f'loading model from {modelfile}')
    model = tf.keras.models.load_model(modelfile, compile=False)
    if args.wts:
        print(f'loading weights from {args.wts}')
        model.load_weights(args.wts)

    # add a leading and a trailing singleton axis around the volume data
    net_input = mri_conf.data[None, ..., None]
    # drop the singleton axes again from the network output
    pred = np.squeeze(model.predict(net_input))

# wrap the raw prediction in a volume that shares the conformed volume's geometry
mri_mask = mri_conf.new(pred)
if args.pred:
    mri_mask.save(args.pred)

# resample the prediction back onto the (normalized) input volume's grid
mri_out = mri_mask.resample_like(mri_in)
# binary map of voxels whose predicted value is at or below the border threshold
mri_interior = mri_out.new(np.where(mri_out <= args.border, 1, 0))

# find largest connected component and use it as the mask
from skimage import measure
cc = measure.label(mri_interior.data, background=0)
# voxel count for every label in a single pass; index 0 is the background, so drop it
cc_areas = np.bincount(cc.ravel())[1:]
if cc_areas.size == 0:
    # nothing survived the threshold - report it instead of failing inside argmax
    sf.system.fatal(f'No voxels found at or below the border threshold ({args.border}); cannot build a mask.')
# labels start at 1, hence the +1 on the argmax index
max_label = np.argmax(cc_areas) + 1
# keep input intensities only inside the largest component
mri_out = mri_out.new(mri_in * (cc == max_label))

mri_out.save(args.outvol)

if args.fv:
    # optionally show the input, conformed volume, interior map and result in freeview
    fv = sf.vis.Freeview()
    # (image, keyword options) pairs in the order they are added;
    # the first name string is kept verbatim, including its ':linked=1' suffix
    layers = (
        (mri_in, dict(name='input:linked=1')),
        (mri_conf, dict(name='conf', visible=0, linked=1, locked=1)),
        (mri_interior, dict(name='interior', colormap='lut', visible=0, linked=1, locked=1)),
        (mri_out, dict(name='out', heatscale='1,3', visible=1, linked=1)),
    )
    for image, options in layers:
        fv.add_image(image, **options)
    fv.show()
