#!/usr/bin/env python3

import os
import sys
import string


def print_usage():
  """Print the command-line synopsis to stdout, one option per line.

  Returns:
    Always 0.
  """
  usage_lines = (
    "USAGE: slicedelay --help",
    "  --o slicedelayfile",
    "  --nslices nslices: total number of slices in the volume",
    "  --order order (up,down,odd,even,siemens)",
    "  --ngroups ngroups (number of slice groups for SMS)",
  )
  for line in usage_lines:
    print(line)
  return 0


def print_help():
  """Print the usage synopsis followed by a description of the output file."""
  # The literal mixes real line breaks with explicit \n escapes; both end up
  # as newlines in the printed text.
  description = '''
  Creates an FSL custom slice delay file for use with slicetimer (--tcustom=sdfile).
  It has a single column of values, one for each slice. Each value is the \n  slice delay measured as a fraction of the TR and range from +0.5 \n  (beginning of the TR) to -0.5 (end of the TR). Used for slice-time \n  correction of fMRI
  '''
  print_usage()
  print(description)


def argnerr(narg,flag):
  """Report that *flag* did not get its *narg* argument(s) and exit.

  Prints the error to stdout, then terminates with exit status 1.
  """
  message = f"ERROR: flag {flag} requires {narg} arguments"
  print(message)
  sys.exit(1)


def parse_args(argv):
  """Parse the slicedelay command line.

  Args:
    argv: full argument list with the program name in position 0
      (normally sys.argv). The list is left unmodified.

  Returns:
    Tuple (sdf, nslices, order, ngroups, debug):
      sdf     -- output file name given with --o ("" if absent)
      nslices -- integer given with --nslices (0 if absent)
      order   -- value of --order, or the name selected by one of the
                 shortcut flags --up/--down/--odd/--even/--siemens
                 ("" if absent); the value is not validated here
      ngroups -- integer given with --ngroups (1 if absent)
      debug   -- 1 once --debug has been seen, otherwise 0

  Prints a message and exits with status 1 on an unrecognized flag, a flag
  that is missing its value, a non-integer --nslices/--ngroups value, or
  --help.
  """
  # Work on a copy without the program name so the caller's list (usually
  # sys.argv) stays intact; slicing also tolerates an empty argv.
  args = list(argv[1:])

  sdf = ""
  nslices = 0
  order = ""
  ngroups = 1
  debug = 0

  # Flags that consume the following argument as their value.
  value_flags = ("--o", "--nslices", "--order", "--ngroups")
  # Shortcut flags that set the slice order without taking a value.
  order_flags = {
    "--up": "up",
    "--down": "down",
    "--odd": "odd",
    "--even": "even",
    "--siemens": "siemens",
  }

  pos = 0
  while pos < len(args):
    flag = args[pos]
    pos += 1
    # --debug only affects flags that come after it on the command line.
    if debug: print(f"flag = {flag}")

    if flag in value_flags:
      if pos >= len(args): argnerr(1,flag)
      value = args[pos]
      pos += 1
      if flag == "--o":
        sdf = value
      elif flag == "--order":
        order = value
      else:
        # --nslices / --ngroups: report a malformed number the same way as
        # the other usage errors rather than with a ValueError traceback.
        try:
          number = int(value)
        except ValueError:
          print(f"ERROR: flag {flag} requires an integer argument, got {value}")
          sys.exit(1)
        if flag == "--nslices":
          nslices = number
        else:
          ngroups = number
    elif flag in order_flags:
      order = order_flags[flag]
    elif flag == "--help":
      print_help()
      sys.exit(1)
    elif flag == "--debug":
      debug = 1
    else:
      print(f"ERROR: flag {flag} not recognized")
      sys.exit(1)

  return sdf, nslices, order, ngroups, debug


def isflag(arg):
  """Return 1 if *arg* looks like a long flag, else 0.

  A flag is at least three characters long and begins with two dashes.
  """
  looks_like_flag = len(arg) >= 3 and arg[0] == "-" and arg[1] == "-"
  return 1 if looks_like_flag else 0


def check_args(sdf, nslices, order):
  """Verify that the required settings were supplied.

  The checks run in order: output file, slice count, slice order. On the
  first missing one, an error line and the help text are printed and the
  program exits with status 1.

  Returns:
    The tuple (sdf, nslices, order), unchanged, when all are present.
  """
  if len(sdf) == 0:
    problem = "ERROR: output file needed"
  elif nslices == 0:
    problem = "ERROR: nslices needed"
  elif len(order) == 0:
    problem = "ERROR: order needed"
  else:
    return sdf, nslices, order

  # Single failure path shared by all three checks.
  print(problem)
  print_help()
  sys.exit(1)


def _anat_slice_order(order, nslicespg):
  """Return the 1-based slice numbers in acquisition order, or None if *order* is unknown."""
  ascending = list(range(1, nslicespg+1))
  odds = ascending[0::2]
  evens = ascending[1::2]
  if order == "up":
    return ascending
  if order == "down":
    return ascending[::-1]
  if order == "odd":
    return odds + evens
  if order == "even":
    return evens + odds
  if order == "siemens":
    # Odd slice count: odd slices first; even slice count: even slices first.
    return odds + evens if nslicespg % 2 == 1 else evens + odds
  return None


def generate_slice_delay_file(sdf, nslices, order, ngroups):
  """Write a slice-delay file with one value per slice.

  The slices are split into *ngroups* equally sized groups and every group
  receives the same delay pattern. Within a group, the k-th acquired slice
  (k = 0..n-1, n = slices per group) gets the delay ((n-1)/2 - k)/n, so the
  first acquired slice has the largest positive value and the last the most
  negative one. The values are written in anatomical slice order, one per
  line, formatted with '%15.13f'.

  Args:
    sdf: path of the output file (overwritten if it exists).
    nslices: total number of slices; must be a positive multiple of ngroups.
    order: acquisition order, one of "up", "down", "odd", "even", "siemens".
    ngroups: number of slice groups; must be at least 1.

  Returns:
    0 on success.

  On invalid input an error message is printed and the program exits with
  status 1 before the output file is opened.
  """
  if ngroups < 1:
    print(f"ERROR: ngroups={ngroups} must be at least 1")
    sys.exit(1)
  if nslices < 1:
    print(f"ERROR: nslices={nslices} must be at least 1")
    sys.exit(1)
  if nslices % ngroups:
    print(f"ERROR: cannot divide nslices={nslices} by ngroups={ngroups}")
    sys.exit(1)

  # Exact, because divisibility was verified above.
  nslicespg = nslices // ngroups

  print(f"sdf {sdf}")
  print(f"nslices {nslices}")
  print(f"order {order}")
  print(f"ngroups {ngroups}")
  print(f"nslicespg {nslicespg}")

  # Delay of the k-th acquired slice within a group.
  AcqSliceDelay = [((nslicespg-1.0)/2.0 - k)/nslicespg for k in range(nslicespg)]
  print(AcqSliceDelay)

  # The order is the same for every group, so compute and validate it once.
  AnatSliceOrder = _anat_slice_order(order, nslicespg)
  if AnatSliceOrder is None:
    print(f"ERROR: slice order {order} not recognized")
    sys.exit(1)

  # For each anatomical slice (ascending slice number), the position at
  # which it is acquired: the argsort of the acquisition order.
  acq_index = sorted(range(nslicespg), key=AnatSliceOrder.__getitem__)
  group_delays = [AcqSliceDelay[k] for k in acq_index]

  AnatSliceDelay = []
  for sg in range(1, ngroups+1):
    print(f"sg {sg}")
    AnatSliceDelay.extend(group_delays)

  with open(sdf, 'w') as fp:
    fp.writelines('%15.13f\n' % d for d in AnatSliceDelay)

  return 0

#-----------------------------------------------------------
# ------ main ----------------------------------------------
#-----------------------------------------------------------

def main(argv):
  """Run the tool end to end and exit with status 0.

  Parses *argv*, checks that the required settings are present, then
  writes the slice-delay file.
  """
  parsed = parse_args(argv)
  sdf, nslices, order = check_args(*parsed[:3])
  # parsed[4] is the debug flag, which is not used after parsing.
  ngroups = parsed[3]
  generate_slice_delay_file(sdf, nslices, order, ngroups)
  sys.exit(0)

# Run the tool only when this file is executed as a script, not when it is
# imported; main() receives the full argument list including the program name.
if __name__ == "__main__":
  main(sys.argv)
