#!/usr/bin/python3

import os
import sys
import numpy as np
import argparse
import surfa as sf

# Parse the command line before the (slow) torch import further down, so that
# `--help` and usage errors are reported immediately.
parser = argparse.ArgumentParser()
# nargs='+' (rather than '*') so that a bare `-s` with no subject names is
# rejected up front instead of loading the model and silently doing nothing.
parser.add_argument('-s', '--subjs', nargs='+', required=True, help='List of subjects to process.')
parser.add_argument('--hemi', required=True, help='Hemisphere to reconstruct (lh or rh).')
parser.add_argument('-d', '--sdir', help='Override SUBJECTS_DIR.')
parser.add_argument('-m', '--model', help='Override default model.')
parser.add_argument('-x', '--suffix', default='topofit', help='Suffix of output surface (default is \'topofit\').')
parser.add_argument('-g', '--gpu', action='store_true', help='Use the GPU.')
parser.add_argument('--rsi', action='store_true', help='Remove self-intersecting faces during the deformation.')
parser.add_argument('--single-iter', action='store_true', help='Prevent deformation steps from running more than once.')
parser.add_argument('--vol', default='norm.mgz', help='Subject volume to use as input.')
args = parser.parse_args()

# import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal

# torch_scatter is optional (it is hard to ship with FreeSurfer). When it cannot
# be loaded we fall back to a custom (but ugly) scatter-max substitute.
# `Exception` rather than a bare `except` keeps the best-effort fallback for any
# import-time failure while no longer swallowing KeyboardInterrupt/SystemExit.
try:
    from torch_scatter import scatter_max
    have_torch_scatter = True
except Exception:
    have_torch_scatter = False

# cuDNN backend flags: `benchmark` turns on the convolution autotuner and
# `deterministic` restricts cuDNN to deterministic algorithms
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

# choose the compute device; CUDA_VISIBLE_DEVICES is overwritten either way
# (first GPU when --gpu is given, '-1' otherwise)
if not args.gpu:
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    device_name = 'CPU'
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    device_name = 'GPU'
print(f'Configuring model on the {device_name}')
device = torch.device('cuda' if args.gpu else 'cpu')

# only the two hemisphere codes are accepted
if args.hemi != 'lh' and args.hemi != 'rh':
    sf.system.fatal("Hemi must be 'lh' or 'rh'.")

# all model resources live under $FREESURFER_HOME/models/topofit
fshome = os.getenv('FREESURFER_HOME')
if fshome is None:
    sf.system.fatal('FREESURFER_HOME env variable must be set! Make sure FreeSurfer is properly sourced.')
basedir = os.path.join(fshome, 'models', 'topofit')

# network weights: an explicit --model wins over the per-hemisphere default
modelfile = args.model if args.model is not None else os.path.join(basedir, f'topofit.{args.hemi}.1.pt')

# template surface and precomputed mesh tables shipped next to the model
avgfile = os.path.join(basedir, f'{args.hemi}.white.ctx.average.6')
icofile, baryfile, mappingfile = (
    os.path.join(basedir, filename) for filename in ('ico.npz', 'bary.npz', 'mapping.npz')
)


class ConvBlock(nn.Module):
    """
    A single N-dimensional convolution (kernel 3, padding 1) followed by a
    LeakyReLU with negative slope 0.2. Used as the building block of the unet.

    Parameters:
        ndims: Dimensionality of the convolution (selects `nn.Conv{ndims}d`).
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        stride: Convolution stride.
    """

    def __init__(self, ndims, in_channels, out_channels, stride=1):
        super().__init__()

        # the attribute names below are part of the checkpoint's state_dict keys
        conv_type = getattr(nn, 'Conv%dd' % ndims)
        self.main = conv_type(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply the convolution and then the activation to `x`."""
        return self.activation(self.main(x))


class ImageUnet(nn.Module):
    """
    Volumetric image-based U-Net architecture.

    Parameters:
        nb_features: Either a pair `[encoder_features, decoder_features]` of
            per-convolution feature counts, or a single integer from which the
            pair is derived (this requires `nb_levels`). Decoder entries beyond
            the number of encoder entries become extra full-resolution convs.
        nb_levels: Number of unet levels. Only valid with an integer `nb_features`.
        infeats: Number of input channels.
        max_pool: Pooling / upsampling factor, an int or one value per level.
        feat_mult: Per-level feature multiplier for an integer `nb_features`.
        nb_conv_per_level: Number of convolutions at each level.
        ndims: Number of spatial dimensions.

    The number of output channels is stored in `final_nf`.
    """

    def __init__(self,
                 nb_features,
                 nb_levels=None,
                 infeats=1,
                 max_pool=2,
                 feat_mult=1,
                 nb_conv_per_level=1,
                 ndims=3):

        super().__init__()

        # expand an integer feature count into explicit encoder/decoder lists
        if isinstance(nb_features, int):
            if nb_levels is None:
                raise ValueError('must provide unet nb_levels if nb_features is an integer')
            level_feats = np.round(nb_features * feat_mult ** np.arange(nb_levels)).astype(int)
            nb_features = [
                np.repeat(level_feats[:-1], nb_conv_per_level),
                np.repeat(np.flip(level_feats), nb_conv_per_level),
            ]
        elif nb_levels is not None:
            raise ValueError('cannot use nb_levels if nb_features is not an integer')

        # decoder entries past the encoder length are full-resolution extras
        enc_nf, dec_nf = nb_features
        nb_dec_convs = len(enc_nf)
        final_convs = dec_nf[nb_dec_convs:]
        dec_nf = dec_nf[:nb_dec_convs]
        self.nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1

        if isinstance(max_pool, int):
            max_pool = [max_pool] * self.nb_levels

        # resampling ops carry no parameters, so plain lists are enough here
        pool_type = getattr(nn, 'MaxPool%dd' % ndims)
        self.pooling = [pool_type(factor) for factor in max_pool]
        self.upsampling = [nn.Upsample(scale_factor=factor, mode='nearest') for factor in max_pool]

        def build_level(level, feature_list, in_nf):
            # one stack of `nb_conv_per_level` chained conv blocks for `level`
            stack = nn.ModuleList()
            for k in range(nb_conv_per_level):
                out_nf = feature_list[level * nb_conv_per_level + k]
                stack.append(ConvBlock(ndims, in_nf, out_nf))
                in_nf = out_nf
            return stack, in_nf

        # down-sampling path, remembering the channel count of every skip tensor
        width = infeats
        skip_widths = [width]
        self.encoder = nn.ModuleList()
        for level in range(self.nb_levels - 1):
            stack, width = build_level(level, enc_nf, width)
            self.encoder.append(stack)
            skip_widths.append(width)

        # up-sampling path; after each level the matching skip is concatenated,
        # which widens the input of the next convolution
        skip_widths = np.flip(skip_widths)
        self.decoder = nn.ModuleList()
        for level in range(self.nb_levels - 1):
            stack, width = build_level(level, dec_nf, width)
            self.decoder.append(stack)
            width += skip_widths[level]

        # surplus convolutions applied at full resolution
        self.remaining = nn.ModuleList()
        for out_nf in final_convs:
            self.remaining.append(ConvBlock(ndims, width, out_nf))
            width = out_nf

        # number of channels produced by forward()
        self.final_nf = width

    def forward(self, x):
        """Run the unet on `x` and return the full-resolution feature volume."""

        # encoder: convolve, stash the result as a skip connection, then pool
        skips = [x]
        for level, stack in enumerate(self.encoder):
            for layer in stack:
                x = layer(x)
            skips.append(x)
            x = self.pooling[level](x)

        # decoder: convolve, upsample, then append the most recent skip
        for level, stack in enumerate(self.decoder):
            for layer in stack:
                x = layer(x)
            x = self.upsampling[level](x)
            x = torch.cat([x, skips.pop()], dim=1)

        # extra convolutions at full resolution
        for layer in self.remaining:
            x = layer(x)

        return x


class DynamicGraphConv(torch.nn.Module):
    """
    Edge-based graph convolution over a fixed mesh adjacency.

    For every adjacency entry (a, b) the features `[x_a, x_b - x_a]` are passed
    through a shared 1x1 convolution, multiplied by the entry's weight, and
    summed into vertex `a`. An optional LeakyReLU (slope 0.3) follows.

    Parameters:
        in_channels: Number of input features per vertex.
        out_channels: Number of output features per vertex.
        mesh_info: Dict providing 'adj_edges_a', 'adj_edges_b' (index tensors),
            'adj_weights' (per-entry weights) and 'size' (number of vertices).
        bias: Whether the 1x1 convolution has a bias term.
        activation: 'leaky' for LeakyReLU(0.3) or None for no activation.

    Raises:
        ValueError: If `activation` is neither 'leaky' nor None.
    """

    def __init__(self, in_channels, out_channels, mesh_info, bias=True, activation='leaky'):

        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edges_a = mesh_info['adj_edges_a']
        self.edges_b = mesh_info['adj_edges_b']
        self.weights = mesh_info['adj_weights']
        self.size = mesh_info['size']

        if activation == 'leaky':
            self.activation = nn.LeakyReLU(0.3)
        elif activation is None:
            self.activation = None
        else:
            raise ValueError(f'unknown activation `{activation}`.')

        # input is doubled because vertex and (neighbor - vertex) are concatenated
        self.conv1d = torch.nn.Conv1d(
            in_channels=in_channels * 2,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding='valid',
            bias=bias)

    def forward(self, input_features):
        """Return the convolved features, one row of `out_channels` per vertex."""

        vertices = input_features[self.edges_a]
        neighbors = input_features[self.edges_b]
        concat_features = torch.cat([vertices, neighbors - vertices], -1)

        # Conv1d wants (batch, channels, length), so edges become the length axis
        concat_features = torch.unsqueeze(torch.swapaxes(concat_features, -2, -1), 0)
        edge_features = self.conv1d(concat_features)
        edge_features = torch.squeeze(edge_features, 0)

        edge_features = torch.swapaxes(edge_features, -2, -1)
        edge_features_weighted = edge_features * self.weights
        indices = self.edges_a.unsqueeze(-1).expand(-1, self.out_channels)

        # allocate the accumulator with the same device and dtype as the edge
        # features; a plain torch.zeros() is always a CPU float32 tensor and
        # makes scatter_add fail when the model runs on the GPU
        accumulator = edge_features_weighted.new_zeros((self.size, self.out_channels))
        features = accumulator.scatter_add(-2, indices, edge_features_weighted)

        # activation
        if self.activation is not None:
            features = self.activation(features)

        return features


class DeformationBlock(torch.nn.Module):
    """
    Graph unet that maps per-vertex input features to a 3-component output
    per vertex (used by the caller as the mesh deformation).

    Parameters:
        mesh_level: Key into `mesh_collection` of the mesh this block runs on.
        mesh_collection: Dict of mesh info dicts indexed by mesh level.
        nb_input_features: Number of features per vertex fed to the block.
        nb_features: Number of features produced by every inner graph conv.
        train_iters: Stored on the block; number of iterations in training mode.
        infer_iters: Stored on the block; number of iterations in eval mode.
        unet_levels: Number of mesh resolutions the unet spans, walking down
            from `mesh_level` one level at a time.
        convs_per_unet_level: Graph convolutions per unet level.
        remove_intersections: Stored on the block for the caller to inspect.
    """

    def __init__(self,
                 mesh_level,
                 mesh_collection,
                 nb_input_features,
                 nb_features=64,
                 train_iters=1,
                 infer_iters=1,
                 unet_levels=1,
                 convs_per_unet_level=4,
                 remove_intersections=False):

        super().__init__()

        self.mesh_level = mesh_level
        self.mesh_collection = mesh_collection
        self.train_iters = train_iters
        self.infer_iters = infer_iters
        self.remove_intersections = remove_intersections

        # down-sampling path: step one mesh level down after every unet level
        # except the deepest, remembering the width of each skip connection
        mesh_cursor = mesh_level
        width = nb_input_features
        skip_widths = [width]
        self.encoder = nn.ModuleList()
        for depth in range(unet_levels):
            stack = nn.ModuleList()
            for _ in range(convs_per_unet_level):
                stack.append(DynamicGraphConv(width, nb_features, self.mesh_collection[mesh_cursor]))
                width = nb_features
            self.encoder.append(stack)
            if depth < unet_levels - 1:
                skip_widths.append(width)
                mesh_cursor -= 1

        # up-sampling path: climb back up, widening the input by the skip width
        skip_widths = np.flip(skip_widths)
        self.decoder = nn.ModuleList()
        for depth in range(unet_levels - 1):
            mesh_cursor += 1
            width += skip_widths[depth]
            stack = nn.ModuleList()
            for _ in range(convs_per_unet_level):
                stack.append(DynamicGraphConv(width, nb_features, self.mesh_collection[mesh_cursor]))
                width = nb_features
            self.decoder.append(stack)

        # last convolution has no activation and emits three values per vertex
        self.finalconv = DynamicGraphConv(width, 3, self.mesh_collection[mesh_cursor], activation=None)

    def forward(self, x):
        """Return the block's 3-component prediction for every vertex in `x`."""

        # encoder: convolve, keep the result as a skip, then pool one level down
        mesh_cursor = self.mesh_level
        skips = [x]
        deepest = len(self.encoder) - 1
        for depth, stack in enumerate(self.encoder):
            for layer in stack:
                x = layer(x)
            if depth == deepest:
                # no pooling below the deepest level
                continue
            skips.append(x)
            x = pool(x, self.mesh_collection[mesh_cursor])
            mesh_cursor -= 1

        # decoder: unpool one level up, append the matching skip, then convolve
        for stack in self.decoder:
            mesh_cursor += 1
            x = unpool(x, self.mesh_collection[mesh_cursor])
            x = torch.cat([x, skips.pop()], dim=-1)
            for layer in stack:
                x = layer(x)

        return self.finalconv(x)


class SurfNet(nn.Module):
    """
    Full surface-fitting network: an image unet produces a feature volume, and
    a sequence of deformation blocks repeatedly samples those features at the
    current vertex positions and moves the vertices.

    Parameters:
        device: Torch device on which mesh tables and helper tensors are placed.

    Note that forward() reads the module-level `args.rsi` flag to decide
    whether self-intersections are removed.
    """

    def __init__(self, device):

        super().__init__()

        self.device = device
        self.config = get_network_config()

        self.image_unet = ImageUnet(self.config['unet_features'])
        self.include_vertex_properties = True
        self.scale_delta_prediction = 10.0

        # vertex properties add 3 scaled coordinates and 3 normal components
        nb_input_features = self.image_unet.final_nf
        if self.include_vertex_properties:
            nb_input_features += 6

        config_blocks = self.config['blocks']
        max_mesh_level = max(block['mesh_level'] for block in config_blocks)
        self.mesh_collection = load_mesh_collection_info(max_mesh_level, self.device)

        self.blocks = nn.ModuleList()
        for block in config_blocks:
            # build the kwargs on a copy so the stored config is left untouched
            block_kwargs = dict(
                block,
                mesh_collection=self.mesh_collection,
                nb_input_features=nb_input_features)
            self.blocks.append(DeformationBlock(**block_kwargs))

    def forward(self, image, coords):
        """
        Deform the input mesh to fit the image.

        Parameters:
            image: Image tensor whose last three axes are spatial; the leading
                axes are passed to the image unet as-is.
            coords: Vertex coordinates with a leading axis of which only index 0
                is used.

        Returns:
            Vertex coordinates of the final mesh.
        """

        image_size = torch.tensor(list(image.shape[-3:]), dtype=torch.float32, device=self.device)

        # predict image-based features
        image_features = self.image_unet(image)

        # squeeze dimensions
        image_features = image_features[0, ...]
        coords = coords[0, ...]

        previous_mesh_level = None

        for block in self.blocks:

            # upsample to the next mesh resolution (if necessary)
            if previous_mesh_level is not None and previous_mesh_level < block.mesh_level:
                indices, weights = self.mesh_collection[block.mesh_level]['upsampler']
                coords = torch.sum(coords[indices] * weights, -2)

            # from here on the coordinates live on this block's mesh, even when
            # the block runs zero iterations; tracking this per block keeps a
            # following block at the same level from upsampling a second time
            previous_mesh_level = block.mesh_level

            # get number of block iterations (usually 1 when training)
            iters = block.train_iters if self.training else block.infer_iters

            for _ in range(iters):

                # sample image-based features at current mesh position
                sampled_features = point_sample(coords, image_features, image_size)

                # add vertex properties
                if self.include_vertex_properties:
                    scaled_coords = coords / torch.max(image_size)
                    normals = compute_normals(coords, self.mesh_collection[block.mesh_level]['faces'])
                    sampled_features = torch.cat([scaled_coords, normals, sampled_features], dim=-1)

                # predict and apply the deformation in mesh space
                deformation = block(sampled_features)

                if self.scale_delta_prediction is not None:
                    deformation = deformation * self.scale_delta_prediction

                coords = coords + deformation

                # remove intersections for the final deformation blocks
                if args.rsi and block.remove_intersections:
                    faces = self.mesh_collection[block.mesh_level]['faces'].cpu().numpy().copy()
                    mesh = sf.Mesh(coords.cpu().numpy(), faces)
                    mesh = sf.mesh.remove_self_intersections(mesh, max_attempts=20)
                    coords = torch.from_numpy(mesh.vertices).type(torch.float32).to(self.device)

        return coords


def cross(vector1, vector2, dim=-1):
    """
    Cross product of two tensors of 3-vectors.

    Both inputs are split into their three components along `dim` (which must
    therefore have length 3), and the result is stacked back along `dim`.
    """
    ax, ay, az = vector1.unbind(dim)
    bx, by, bz = vector2.unbind(dim)
    components = (
        ay * bz - az * by,
        az * bx - ax * bz,
        ax * by - ay * bx,
    )
    return torch.stack(components, dim=dim)


def point_sample(coords, features, image_size, normed=False):
    """
    Interpolate a feature volume at a set of points.

    Parameters:
        coords: (points, 3) sample positions. Unless `normed` is set these are
            index coordinates along the three spatial axes of `features`.
        features: (channels, X, Y, Z) feature volume.
        image_size: Tensor with the three spatial sizes, used to map index
            coordinates to the [-1, 1] range expected by grid_sample.
        normed: If True, `coords` are already in the [-1, 1] range.

    Returns:
        (points, channels) tensor of sampled features.
    """
    if not normed:
        half_size = (image_size - 1) / 2
        coords = (coords - half_size) / half_size

    # grid_sample wants a (batch, D, H, W, 3) grid; use one point per D entry
    nb_points, nb_dims = coords.shape[-2], coords.shape[-1]
    grid = coords.reshape(1, nb_points, 1, 1, nb_dims)

    # the first grid component addresses the last volume axis, so reverse the
    # spatial axes to keep coordinate i paired with spatial axis i
    volume = features.unsqueeze(0).transpose(-1, -3)
    sampled = torch.nn.functional.grid_sample(volume, grid, mode='bilinear', align_corners=True)

    # (1, channels, points, 1, 1) -> (points, channels)
    return sampled.squeeze(0).squeeze(-1).squeeze(-1).transpose(-1, -2)


def gather_faces(coords, face_indices):
    """Look up the coordinates of every vertex referenced by `face_indices`."""
    face_coords = coords[face_indices]
    return face_coords


def face_normals(face_coords, clockwise=False, normalize=True):
    """
    Compute one normal per triangle.

    Parameters:
        face_coords: (..., 3, 3) tensor holding the three corner positions of
            each face along the second-to-last axis.
        clockwise: When False the cross product (v1 - v0) x (v2 - v0) is negated.
        normalize: When True the normals are scaled to unit L2 length.
    """
    corner0, corner1, corner2 = face_coords.unbind(-2)
    normals = cross(corner1 - corner0, corner2 - corner0, dim=-1)
    if not clockwise:
        normals = -1.0 * normals
    if not normalize:
        return normals
    return torch.nn.functional.normalize(normals, p=2, dim=-1)


def compute_normals(coords, face_indices):
    """
    Compute unit vertex normals of a triangle mesh.

    The unnormalized normal of every face is added to each of its three
    vertices, and the per-vertex sums are then divided by their length.

    Parameters:
        coords: (vertices, 3) vertex positions.
        face_indices: (faces, 3) integer vertex indices of each triangle.

    Returns:
        (vertices, 3) tensor of vertex normals. A vertex whose accumulated
        normal is zero (for example one referenced by no face) yields 0 / 0.
    """
    face_coords = gather_faces(coords, face_indices)
    mesh_face_normals = face_normals(face_coords, clockwise=False, normalize=False)

    # zeros_like keeps the accumulator on the device (and in the dtype) of the
    # coordinates; torch.zeros(shape) is always CPU float32, which breaks
    # scatter_add for GPU inputs
    unnorm_vertex_normals = torch.zeros_like(coords)
    for corner in range(3):
        targets = face_indices[..., corner:corner + 1].expand(-1, 3)
        unnorm_vertex_normals = unnorm_vertex_normals.scatter_add(-2, targets, mesh_face_normals)

    vector_norms = torch.sqrt(torch.sum(unnorm_vertex_normals ** 2, dim=-1, keepdims=True))
    return unnorm_vertex_normals / vector_norms


def pool(features, mesh_info):
    """
    Max-pool vertex features across the mapping stored in `mesh_info`,
    gathering from the 'pooling_b' indices into 'pooling_size_a' output rows
    addressed by 'pooling_a'. 'pooling_emg_a' is only present (and only used)
    when torch_scatter is unavailable.
    """
    return gather_vertex_features(
        features,
        size=mesh_info['pooling_size_a'],
        sources=mesh_info['pooling_b'],
        targets=mesh_info['pooling_a'],
        emg=mesh_info.get('pooling_emg_a'))


def unpool(features, mesh_info):
    """
    Reverse direction of `pool`: gather from the 'pooling_a' indices into
    'pooling_size_b' output rows addressed by 'pooling_b'. 'pooling_emg_b' is
    only present (and only used) when torch_scatter is unavailable.
    """
    return gather_vertex_features(
        features,
        size=mesh_info['pooling_size_b'],
        sources=mesh_info['pooling_a'],
        targets=mesh_info['pooling_b'],
        emg=mesh_info.get('pooling_emg_b'))


def gather_vertex_features(features, size, sources, targets, emg):
    """
    Gather rows of `features` and reduce them with an element-wise maximum.

    Parameters:
        features: (vertices, channels) input features.
        size: Number of rows in the output.
        sources: Index tensor selecting the rows of `features` to gather.
        targets: Output row that each gathered row is reduced into
            (used only on the torch_scatter path).
        emg: Index tensor into the gathered rows whose second-to-last axis is
            reduced with max (used only on the fallback path).
            NOTE(review): presumably lists, per output row, the gathered rows
            belonging to it - confirm against the code that builds the npz.

    Returns:
        The max-reduced features, one row per output vertex.

    Raises:
        ValueError: If torch_scatter is unavailable and `emg` is None.
    """
    gathered_features = features[sources]

    if have_torch_scatter:
        # Start from -1000 so that any real feature value wins the max. The
        # buffer follows the input's device and dtype instead of relying on the
        # module-level `device` and a hard-coded float32.
        out = torch.full(
            (int(size), features.shape[-1]),
            -1000.0,
            dtype=features.dtype,
            device=features.device)
        out, _ = scatter_max(gathered_features, targets, -2, out=out)
        return out

    if emg is None:
        # indexing with None would silently add an axis and give a wrong result
        raise ValueError('emg indices are required when torch_scatter is unavailable.')
    return gathered_features[emg].max(-2)[0]


def get_network_config(name=None):
    """
    Return the network configuration dictionary.

    Parameters:
        name: Name of the configuration. Only the default (None) exists.

    Returns:
        Dict with the image unet features ('unet_features'), the list of
        deformation block settings ('blocks'), and
        'coarse_training_skip_blocks'. The last block runs twice at inference
        unless the module-level `args.single_iter` flag is set.

    Raises:
        ValueError: If `name` is not None. Previously this fell through to
            `return config` with `config` unbound.
    """
    if name is not None:
        raise ValueError(f'unknown network configuration `{name}`.')

    # number of inference iterations of the final block
    i = 1 if args.single_iter else 2

    config = {
        'unet_features': [
            [16, 32, 32, 64],
            [64, 64, 64, 64, 64]],
        'coarse_training_skip_blocks': [7],
        'blocks': [
            {'mesh_level': 1, 'train_iters': 1, 'infer_iters': 1, 'unet_levels': 1, 'convs_per_unet_level': 3},
            {'mesh_level': 2, 'train_iters': 1, 'infer_iters': 1, 'unet_levels': 2, 'convs_per_unet_level': 2},
            {'mesh_level': 3, 'train_iters': 1, 'infer_iters': 1, 'unet_levels': 3, 'convs_per_unet_level': 2},
            {'mesh_level': 4, 'train_iters': 1, 'infer_iters': 1, 'unet_levels': 3, 'convs_per_unet_level': 2},
            {'mesh_level': 5, 'train_iters': 1, 'infer_iters': 1, 'unet_levels': 3, 'convs_per_unet_level': 2},
            {'mesh_level': 6, 'train_iters': 1, 'infer_iters': 1, 'unet_levels': 3, 'convs_per_unet_level': 2, 'remove_intersections': True},
            {'mesh_level': 6, 'train_iters': 1, 'infer_iters': 1, 'unet_levels': 3, 'convs_per_unet_level': 2, 'remove_intersections': True},
            {'mesh_level': 7, 'train_iters': 1, 'infer_iters': i, 'unet_levels': 3, 'convs_per_unet_level': 2, 'remove_intersections': True},
        ],
    }

    return config


def load_mesh_collection_info(max_mesh_level, device):
    """Load the mesh info of levels 1 through `max_mesh_level`, keyed by level."""
    collection = {}
    for level in range(1, max_mesh_level + 1):
        collection[level] = load_mesh_info(level, device)
    return collection


def load_mesh_info(mesh_level, device):
    """
    Load the precomputed tables of one mesh level onto `device`.

    Reads the module-level `icofile` and `baryfile` archives and returns a dict
    with the vertex count ('size'), 'faces', the adjacency edges and weights
    ('adj_edges_a', 'adj_edges_b', 'adj_weights'), the 'upsampler'
    [source indices, barycentric weights] pair, and the pooling mapping between
    level `mesh_level - 1` and `mesh_level` ('pooling_a', 'pooling_b',
    'pooling_values', 'pooling_size_a', 'pooling_size_b'). When torch_scatter
    is unavailable the 'pooling_emg_a' / 'pooling_emg_b' tables are added.
    """

    def as_index(array):
        # convert straight to int64; going through a float32 tensor first would
        # lose precision for indices above 2**24
        return torch.from_numpy(np.asarray(array).astype(np.int64)).to(device)

    def as_float(array):
        return torch.from_numpy(np.asarray(array).astype(np.float32)).to(device)

    ico = f'ico-{mesh_level}'
    mapping = f'mapping-{mesh_level - 1}-to-{mesh_level}'

    # NOTE(review): allow_pickle=True unpickles object arrays, which is only
    # acceptable because these archives are expected to ship with FreeSurfer.
    # The context managers close the archives once every array has been read.
    with np.load(icofile, allow_pickle=True) as npz_ico, \
            np.load(baryfile, allow_pickle=True) as npz_bary:

        # every key access decodes the array again, so read shared arrays once
        adjacency = npz_ico[f'{ico}-adjacency-indices']
        pooling = npz_ico[f'{mapping}-indices']
        pooling_shape = npz_ico[f'{mapping}-shape']

        mesh_info = {
            'size': len(npz_ico[f'{ico}-vertices']),
            'faces': as_index(npz_ico[f'{ico}-faces']),
            'adj_edges_a': as_index(adjacency[:, 0]),
            'adj_edges_b': as_index(adjacency[:, 1]),
            'adj_weights': as_float(np.expand_dims(npz_ico[f'{ico}-adjacency-values'], -1)),
            'upsampler': [
                as_index(npz_bary[f'{ico}-sources']),
                as_float(np.expand_dims(npz_bary[f'{ico}-bary'], -1)),
            ],
            'pooling_a': as_index(pooling[:, 0]),
            'pooling_b': as_index(pooling[:, 1]),
            'pooling_values': as_float(npz_ico[f'{mapping}-values']),
            'pooling_size_a': pooling_shape[0],
            'pooling_size_b': pooling_shape[1],
        }

        # index tables for the scatter-max substitute
        if not have_torch_scatter:
            mesh_info['pooling_emg_a'] = as_index(npz_ico[f'{mapping}-emg-a'])
            mesh_info['pooling_emg_b'] = as_index(npz_ico[f'{mapping}-emg-b'])

    return mesh_info


# initialize model and load weights
with torch.no_grad():
    model = SurfNet(device).to(device)
    checkpoint = torch.load(modelfile, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

# load surface template files (the archives are closed once the arrays are read)
insize = (96, 144, 208)
with np.load(icofile) as npz_ico:
    target_faces = npz_ico['ico-7-faces']
with np.load(mappingfile, allow_pickle=True) as npz_mapping:
    source_mapping = npz_mapping['mapping-6-to-1']

# template surface in voxel space, reduced to the vertices of the starting mesh
template = sf.load_mesh(avgfile)
template_vertices = template.convert(space='vox').vertices
template_midpoint = np.mean([template_vertices.min(0), template_vertices.max(0)], 0)
template_vertices = template_vertices[source_mapping]

subjectsdir = args.sdir if args.sdir is not None else os.environ.get('SUBJECTS_DIR')
if subjectsdir is None:
    sf.system.fatal('Must set SUBJECTS_DIR or use --sdir flag.')

print(f'Using subject volume {args.vol}')

for subj in args.subjs:

    # load norm and talairach transform
    print(os.path.join(subjectsdir, subj, 'mri', args.vol))
    norm = sf.load_volume(os.path.join(subjectsdir, subj, 'mri', args.vol))
    trf = sf.load_affine(os.path.join(subjectsdir, subj, 'mri/transforms/talairach.lta'))

    # the inverse transform is applied to the template; invert it only once
    inv_trf = trf.inv()

    # align template and build cropping
    corner = np.floor(inv_trf.transform(template_midpoint) - (np.array(insize) / 2)).astype('int')
    corner = np.clip(corner, 0, None)
    in_surf = (inv_trf.transform(template_vertices) - corner)[np.newaxis]
    cropping = tuple([slice(a, o) for a, o in zip(corner, corner + insize)])

    # prepare image: intensities are divided by 180 and clipped to [0, 1]
    cropped = (norm[cropping].astype(np.float32) / 180.0).clip(0, 1)
    if np.array_equal(cropped.baseshape, insize):
        # use the raw array (as the padded branch does) so that `in_image` is a
        # plain ndarray either way before it is handed to torch
        in_image = cropped.data
    else:
        # the crop ran past the volume bounds, so zero-pad up to the input size
        window = tuple([slice(0, s) for s in cropped.baseshape])
        in_image = np.zeros(insize, dtype='float32')
        in_image[window] = cropped.data
    in_image = in_image[np.newaxis, np.newaxis, ...]

    # predict (inputs must live on the same device as the model weights)
    with torch.no_grad():
        in_image = torch.from_numpy(in_image.astype(np.float32, copy=False)).to(device)
        in_surf = torch.from_numpy(in_surf.astype(np.float32, copy=False)).to(device)
        vertices = model(in_image, in_surf).cpu().numpy().squeeze()

    # transform back to image space
    surf = sf.Mesh(vertices, target_faces, space='vox', geometry=cropped).convert(space='surf', geometry=norm)

    # write file
    outfile = os.path.join(subjectsdir, subj, f'surf/{args.hemi}.white.{args.suffix}')
    surf.save(outfile)
    print(f'Wrote surface to: {outfile}')
