#!/usr/bin/python3

import os
import numpy as np
import argparse
import surfa as sf


# Command-line interface. Arguments are registered in the same order they
# should appear in the --help listing.
parser = argparse.ArgumentParser()

# required arguments
parser.add_argument(
    '-i', '--invol',
    help='input MRI volume',
    required=True,
)
parser.add_argument(
    '-o', '--outvol',
    help='output MRI volume',
    required=True,
)
parser.add_argument(
    '--hemi',
    help='hemi to process',
    required=True,
)

# optional extra outputs and display
parser.add_argument(
    '--pred',
    help='write prediction volume',
)
parser.add_argument(
    '--norm',
    help='write normalized volume',
)
parser.add_argument(
    '--fv',
    action='store_true',
    help='bring up freeview to show results',
)
parser.add_argument(
    '--norm_mean',
    action='store_true',
    help='normalize output mean to match input mean',
)
parser.add_argument(
    '--write_rounds',
    action='store_true',
    help='write normalization results after each round',
)

# processing parameters
parser.add_argument(
    '--uthresh',
    type=float,
    default=5000,
    help='specify threshold to erase above',
)
parser.add_argument(
    '--sigma',
    type=float,
    default=.5,
    help='sigma to smooth bias field',
)
parser.add_argument(
    '--nrounds',
    type=int,
    default=1,
    help='number of rounds of iterative normalization to apply',
)
parser.add_argument(
    '--multichannel',
    action='store_true',
    help='specify that data has multiple channels',
)

# model, weights and device selection
parser.add_argument(
    '--model',
    help='use alternative model file',
)
parser.add_argument(
    '--wts',
    help='alternative weights filename',
)
parser.add_argument(
    '--gpu',
    help='GPU number - if not supplied, CPU is used',
)

args = parser.parse_args()

# delay slow TF import after parsing cmd line
import tensorflow as tf
import neurite as ne

# Reject anything other than the two supported hemisphere labels.
if args.hemi not in ('lh', 'rh'):
    sf.system.fatal(f'Hemi specification must be either `lh` or `rh`. User provided `{args.hemi}`.')

# Read the input volume.
mri_in = sf.load_volume(args.invol)

# Zero every voxel brighter than the upper threshold. A threshold of 0 is
# falsy and therefore disables this step.
if args.uthresh:
    mri_in[mri_in > args.uthresh] = 0

# Normalize input intensities: shift the minimum to zero, divide by the 99th
# percentile, then clip the result to the range [0, 2].
mri_in = mri_in - mri_in.min()
mri_in = (mri_in / mri_in.percentile(99)).clip(0, 2)

# Optionally write the normalized input. The destination is the filename the
# user passed to --norm (previously a hard-coded 'norm.mgz' was written and the
# supplied filename was ignored).
if args.norm:
    mri_in.save(args.norm)

# Pick the compute device: CPU unless a GPU was requested with --gpu.
if not args.gpu:
    device = '/cpu:0'
else:
    device, ngpus = ne.tf.utils.setup_device(args.gpu)

# Custom neurite layers handed to the Keras model loader so that it can
# resolve them by name.
layer_dict = dict(
    Resize=ne.layers.Resize,
    GaussianBlur=ne.layers.GaussianBlur,
)

# Resolve the model file: the default is the hemisphere-specific model under
# $FREESURFER_HOME/models, unless --model supplies an alternative.
if args.model is None:
    fshome = os.environ.get('FREESURFER_HOME')
    if fshome is None:
        sf.system.fatal('FREESURFER_HOME env variable must be set! Make sure FreeSurfer is properly sourced.')
    model_basename = f'exvivo.norm.{args.hemi}.h5'
    modelfile = os.path.join(fshome, 'models', model_basename)
else:
    modelfile = args.model
    print('Using custom model weights')

# Iterative normalization. Each round predicts a multiplicative field from the
# conformed volume, maps it back to the input space, smooths it and applies it
# to the normalized input. The result of every round is kept in `norm_list`.

# Without at least one round no output volume would ever be produced.
if args.nrounds < 1:
    sf.system.fatal(f'Number of rounds must be at least 1. User provided `{args.nrounds}`.')

# Base name for the per-round intermediate files: strip a recognized volume
# extension from the end of --outvol, otherwise fall back to dropping whatever
# extension (if any) the filename has. Only the end of the path is inspected, so
# a '.mgz' inside a directory name is never mistaken for the extension.
for known_ext in ('.nii.gz', '.mgz'):
    if args.outvol.endswith(known_ext):
        round_base = args.outvol[:-len(known_ext)]
        break
else:
    round_base = os.path.splitext(args.outvol)[0]

norm_list = []
print(f'using device {device}')
for round_idx in range(args.nrounds):
    print(f'Computing normalization for round {round_idx}')
    with tf.device(device):
        if round_idx == 0:
            # One-time setup: load the model, expose the output of its
            # second-to-last layer, and conform the input to the spatial shape
            # the model expects.
            print(f'loading model from {modelfile}')
            model_in = tf.keras.models.load_model(modelfile, compile=False, custom_objects=layer_dict)
            model = tf.keras.Model(model_in.inputs, [model_in.layers[-2].output])
            vol_shape = model.input.shape.as_list()[1:-1]
            mri_conf = mri_in.conform(shape=vol_shape, voxsize=1, orientation='LIA')

            # Alternative weights only need to be loaded once; prediction does
            # not modify them.
            if args.wts:
                print(f'loading weights from {args.wts}')
                model.load_weights(args.wts)

        # Add batch and channel axes for the network, then drop them again.
        bias_field = model.predict(mri_conf.data[np.newaxis, ..., np.newaxis])
        mri_bias = mri_conf.new(bias_field.squeeze())

        # Corrected volume in conformed space (written by --pred after the loop).
        mri_norm = mri_conf * mri_bias

        # Map the estimated bias field back to unconformed space and apply it.
        # NOTE(review): the field is always applied to the original `mri_in`,
        # not to the previous round's output - confirm this is intended.
        mri_bias_noconf = mri_bias.resample_like(mri_in).smooth(args.sigma)
        mri_out = mri_in * mri_bias_noconf
        if args.norm_mean:
            print('normalizing the output mean to match the input mean')
            mri_out *= mri_in.mean() / mri_out.mean()

        norm_list.append(mri_out.copy())

        # All rounds but the last: prepare the network input for the next round
        # from this round's output.
        if round_idx < args.nrounds - 1:
            mri_conf = mri_out.conform(shape=vol_shape, voxsize=1, orientation='LIA')
            mri_conf = mri_conf - mri_conf.min()
            # NOTE(review): scales by the 99th percentile of `mri_in` rather
            # than of `mri_conf` - confirm this is intended.
            mri_conf = (mri_conf / mri_in.percentile(99)).clip(0, 2)

            if args.write_rounds:
                # Intermediate files are numbered from 1 and always use .mgz.
                tmp_fname = f'{round_base}{round_idx + 1}.mgz'
                print(f'writing intermediate results to {tmp_fname}')
                mri_out.save(tmp_fname)

# Optionally write the conformed-space corrected volume from the last round.
if args.pred:
    mri_norm.save(args.pred)

# Write the final output volume.
print(f'writing output volume to {args.outvol}')
mri_out.save(args.outvol)

# Optionally bring up freeview with the input, the intermediate volumes and
# the result of every round.
if args.fv:
    print('displaying results in fv')

    # Option suffix shared by the layers that start out hidden and locked.
    hidden_opts = ':linked=1:visible=0:locked=1'
    bias_opts = ':colormap=heat:heatscale=1,3:heatscale_offset=1' + hidden_opts
    bias_conf_opts = ':colormap=heat:heatscale_offset=1:heatscale=0,.2' + hidden_opts

    fv = sf.vis.Freeview()
    fv.add_image(mri_in, name='input:linked=1')
    fv.add_image(mri_conf, name='conf', opts=hidden_opts)
    fv.add_image(mri_bias_noconf, name='bias', opts=bias_opts)
    for round_no, round_vol in enumerate(norm_list):
        layer_name = f'out{round_no}'
        fv.add_image(round_vol, name=layer_name, opts=':heatscale=1,3:visible=1:linked=1')
    fv.add_image(mri_bias, name='bias_conf', opts=bias_conf_opts)
    fv.show(title=modelfile)
