#!/usr/bin/bash

## If you used the .rpm or .deb linux installer to install freesurfer under
## /usr/local, then preface this command with sudo as admin privs are needed.
##
## $ sudo FREESURFER_HOME=/usr/local/freesurfer/7.5.0 ./fs_install_cuda
##
## Otherwise, it should work to run w/o sudo privs.
## 
## $ FREESURFER_HOME=/usr/local/freesurfer/7.5.0 ./fs_install_cuda
##
## You will be prompted to confirm whether you want to modify the freesurfer python
## (fspython) installation to install the cuda version of torch along with other cuda
## python packages. You can always abort if the install change the script reports
## does not look correct.

# Require FREESURFER_HOME: the install location to modify is taken from it.
# Diagnostics are written to stderr so they are not mixed into captured stdout.
if [[ -z "${FREESURFER_HOME:-}" ]]; then
    {
        echo "ERROR: must set FREESURFER_HOME before installing"
        echo "INFO: if you're using sudo, make sure to pass the FS home variable with"
        # $* (not $@): the arguments are being joined into a single message string.
        echo "INFO: sudo FREESURFER_HOME=\$FREESURFER_HOME fs_install_cuda $*"
    } >&2
    exit 1
fi

# Root of the FreeSurfer install that will be modified.
path_dev_base="$FREESURFER_HOME"
# The expansion is quoted so that a path containing spaces is tested as one
# word; unquoted, the test itself errors out and the check is silently skipped.
if [[ ! -e "$path_dev_base" ]]; then
   echo "Cannot find $path_dev_base  Please check that it exists - exiting." >&2
   exit 1
fi

# Python interpreter bundled with FreeSurfer (fspython); the pip commands in
# this script are run through it.
python_binary="${path_dev_base}/python/bin/python3"
if [[ ! -e "$python_binary" ]]; then
   echo "Cannot stat fspython binary = $python_binary - exiting." >&2
   exit 1
fi

# Sanity check that pip is usable through fspython before changing anything.
# pip is run on its own rather than at the head of a pipeline, so the status
# tested here is pip's own exit code and not that of a downstream filter
# (a pipeline reports the status of its last command unless pipefail is set).
if ! "$python_binary" -m pip freeze > /dev/null; then
   echo "pip freeze command failed - cannot continue - exiting." >&2
   exit 1
fi

# exit 0

# Installed torch version as reported by pip, e.g. "1.13.1+cpu" or "1.13.1".
# Only the line for the torch package itself is taken, so other packages whose
# names contain "torch" (torchvision, torchaudio, ...) cannot leak into it.
torch_rev=$("$python_binary" -m pip freeze | sed -n 's;^torch==;;p')
# Same version with any local build tag ("+cpu", "+cu118", ...) stripped.
torch_rev_numeric="${torch_rev%%+*}"

if [[ -z "$torch_rev" ]]; then
   # Without this guard an empty version would compare equal to its (also
   # empty) numeric form and be reported as "already installed".
   echo "Cannot determine installed torch version in ${path_dev_base}/python - exiting." >&2
   exit 1
elif [[ "${torch_rev}" == "${torch_rev_numeric}+cpu" ]]; then
   ## remove cpu version and install non-cpu version which will also add other nvidia packages
   read -r -n 1 -p "Replace torch ${torch_rev} with torch ${torch_rev_numeric} in ${path_dev_base}/python? (y/n): "
   echo    # move to a new line after the single-character reply
   if [[ ! $REPLY =~ ^[Yy]$ ]]; then
      echo "Install aborted - exiting."
      exit 0
   fi
   # NOTE(review): presumably a grace period to Ctrl-C before anything is changed - confirm.
   sleep 5

   if ! "$python_binary" -m pip uninstall -y torch; then
      echo "Pip uninstall of torch command failed - cannot continue - exiting." >&2
      echo "Try using sudo to run the command if you have not already done so." >&2
      exit 1
   fi

   if ! "$python_binary" -m pip install --no-cache-dir "torch==${torch_rev_numeric}"; then
      echo "pip install of torch command failed - cannot continue - exiting." >&2
      echo "NOTE: torch ${torch_rev} was already uninstalled and must be reinstalled manually." >&2
      exit 1
   fi

   ## check there is now a libtorch_cuda.so
   # Ask the interpreter where its site-packages directory is instead of
   # assuming a particular pythonX.Y directory name; fall back to the
   # historical python3.8 location if the query fails.
   site_dir=$("$python_binary" -c 'import sysconfig; print(sysconfig.get_path("platlib"))' 2>/dev/null) || site_dir=""
   if [[ -z "$site_dir" ]]; then
      site_dir="$path_dev_base/python/lib/python3.8/site-packages"
   fi
   path_torch_libcuda="$site_dir/torch/lib/libtorch_cuda.so"
   if [[ ! -e "$path_torch_libcuda" ]]; then
      echo "Cannot stat a libtorch_cuda.so = $path_torch_libcuda in fspython torch module" >&2
      exit 1
   else
      echo "cuda install success"
   fi
elif [[ "${torch_rev}" == "${torch_rev_numeric}" || "${torch_rev}" == "${torch_rev_numeric}+cu"* ]]; then
   ## non-cpu version apparently already installed (untagged or "+cuNNN" build)
   echo "Looks like non-cpu version of torch $torch_rev already installed in ${path_dev_base}/python - nothing to do, exiting."
   exit 0
else
   echo "Do not know how to process torch_rev $torch_rev" >&2
   exit 1
fi

