#!/usr/bin/env python3

import os
import argparse
import numpy as np
import surfa as sf


# command line interface: a stats table is rasterized onto either a
# segmentation volume or a surface parcellation

parser = argparse.ArgumentParser()
parser.add_argument(
    '-t', '--table',
    required=True, metavar='FILE', help='Input table.')
parser.add_argument(
    '-o', '--out',
    required=True, metavar='FILE', help='Output map.')
parser.add_argument(
    '-s', '--seg',
    metavar='FILE', help='Segmentation to map to.')
parser.add_argument(
    '-p', '--parc',
    metavar='FILE', help='Parcellation to map to.')
parser.add_argument(
    '-c', '--columns',
    nargs='*', help='Table columns to map. All are included by default.')
parser.add_argument(
    '-l', '--lut',
    metavar='FILE', help='Alternative lookup table.')
args = parser.parse_args()

# exactly one mapping target (--seg or --parc) has to be given
if args.seg and args.parc:
    sf.system.fatal('Must provide only one of --seg or --parc input.')
elif args.seg is None and args.parc is None:
    sf.system.fatal('Must provide either --seg or --parc input.')

# columns to extract - an empty `--columns` list is treated like an omitted
# flag, since selecting zero columns would produce an output map with no frames
columns = args.columns

# read the input table, keeping only lines that contain at least one token so
# that blank and whitespace-only lines cannot crash the row parsing below
with open(args.table, 'r') as file:
    rows = [items for items in (line.split() for line in file.read().splitlines()) if items]

if not rows:
    sf.system.fatal(f'Input table "{args.table}" is empty.')

# get table header and build a mapping of the columns to extract - the first
# header token sits above the row-name column, so it is not a data column
header = rows[0][1:]
if not columns:
    columns = header
else:
    for col in columns:
        if col not in header:
            sf.system.fatal(f'Column "{col}" is not in table.')
column_mapping = [header.index(col) for col in columns]

# pull the data: each row is `<name> <value> <value> ...` and only the
# requested columns are kept, in the order they were requested
table = {}
for items in rows[1:]:
    try:
        values = np.array([float(n) for n in items[1:]])[column_mapping]
    except (ValueError, IndexError):
        # non-numeric entry, or fewer values in the row than the header promises
        sf.system.fatal(f'Could not parse table row "{items[0]}".')
    else:
        table[items[0]] = values

# remove these unneeded rows if they exist
table.pop('eTIV', None)
table.pop('BrainSegVolNotVent', None)


# function for extracting a label index by name
def find_label_index(labels, label):
    """Resolve a table row name to a label index in a lookup table.

    The name is tried as-is first, then with known metric suffixes removed,
    and finally with several hemisphere-prefix spellings.

    Args:
        labels: Lookup object exposing `search(name, exact=True)`, which is
            expected to return an index on a match and None otherwise.
        label: Row name taken from the table.

    Returns:
        The first index returned by `labels.search`, or None if no candidate
        spelling of the name matched.
    """

    # simple search - this will likely fail for any surface-based labels
    index = labels.search(label, exact=True)
    if index is not None:
        return index

    # prune out common metrics that unfortunately are added to the structure name
    prune_list = ('_area', '_volume', '_thickness', '_thicknessstd',
                  '_thickness.T1', '_meancurv', '_gauscurv', '_foldind', '_curvind')

    # the longest keys have to go first: removing '_thickness' before
    # '_thicknessstd' or '_thickness.T1' would leave a stray 'std' or '.T1'
    # glued to the structure name and the lookup could never succeed
    pruned_label = label
    for key in sorted(prune_list, key=len, reverse=True):
        pruned_label = pruned_label.replace(key, '')

    index = labels.search(pruned_label, exact=True)
    if index is not None:
        return index

    # if that didn't work it's probably because ctx needs to be added to the prefix
    # unfortunately there's no consistent syntax across parcellations, so this is a complete mess
    for key in ('lh', 'rh'):
        prefix = f'{key}_'

        # candidate spellings, in the order they are tried:
        # 'ctx_lh_name', 'ctx-lh-name' and finally the bare 'name'
        candidates = [pruned_label.replace(prefix, f'ctx{mod}{key}{mod}') for mod in ('_', '-')]
        candidates.append(pruned_label.replace(prefix, ''))

        for cortex_label in candidates:
            index = labels.search(cortex_label, exact=True)
            if index is not None:
                return index

    return None


# load the inputs and prepare output map - the map gets one trailing frame
# per requested table column and starts out as all zeros
if args.seg:
    input_seg = sf.load_volume(args.seg)
    map_image = input_seg.new(np.zeros((*input_seg.shape, len(columns))))
else:
    input_seg = sf.load_overlay(args.parc)
    map_image = input_seg.new(np.zeros((input_seg.shape[0], len(columns))))

# load the appropriate lookup table: an explicit --lut wins, segmentations fall
# back to the FreeSurfer default, parcellations use the lookup they carry
if args.lut:
    labels = sf.load_label_lookup(args.lut)
elif args.seg:
    fs_home = os.environ.get('FREESURFER_HOME')
    if fs_home is None:
        # fail with an actionable message rather than a bare KeyError
        sf.system.fatal('FREESURFER_HOME is not set. Provide a lookup table with --lut.')
    labels = sf.load_label_lookup(os.path.join(fs_home, 'FreeSurferColorLUT.txt'))
else:
    labels = input_seg.labels
    # NOTE(review): assumes surfa leaves `labels` as None when the file has no
    # embedded lookup - confirm; without this the search below would fail
    # with an AttributeError
    if labels is None:
        sf.system.fatal('Parcellation has no embedded lookup table. Provide one with --lut.')

# match up each label name with an index and rasterize the mapping
for label, values in table.items():

    # there are some more non-structure metrics that might be hiding around
    if 'SurfArea' in label:
        continue

    index = find_label_index(labels, label)
    if index is None:
        print(f'warning: {label} does not exist in lookup table.')
        continue

    # write the row's values into every element carrying this label
    map_image[input_seg == index] = values

map_image.save(args.out)
