#!/usr/bin/python3

# This is a script that will print out information from a comma
# separated value file (CSV). It is assumed that the first row has
# field names. To run, pass it the csv file name and the field(s) that
# you want to print out.
#
# There are also some ADNI-specific flags:
# --rid RID: only print out rows that match the given RID. Multiple RIDs are ok.
# --s subject : convert FreeSurfer subject name into an RID. This only requires
#   the removal of any leading 0s (eg, 0041 becomes 41). Multiple subjects are ok.
# --v viscode : only print out rows that match the given VISCODE (visit code)
#
# $Id$

import sys
import os
import csv

#---------------------------------------------------------
def print_help():
  """Print the command-line usage summary to stdout.

  Lists every option family accepted by parse_args, including the
  --tsv and --debug switches.

  Returns:
    0, always.
  """
  usage_lines = (
    "USAGE: csvprint",
    "  --csv csvfile (can be a tsv file)",
    "  --f field1 <field2...>",
    "  --rid rid1 <rid2...> : match RID for ADNI",
    "  --s subjid1 <subjid2...> : create RID from FS subject name (leading 0s)",
    "  --v viscode <viscode> : match VISCODE for ADNI",
    "  --l Label : match Label for GSP",
    "  --tsv : input is tab-separated (also selected automatically for .tsv files)",
    "  --debug : print each flag as it is parsed",
  )
  for line in usage_lines:
    print(line)
  return 0
#end def print_help:

#---------------------------------------------------------
def argnerr(narg,flag):
  """Report that a flag is missing its argument(s) and abort.

  Args:
    narg: number of arguments the flag requires.
    flag: the flag as it appeared on the command line.

  Raises:
    SystemExit: always, with exit status 1.
  """
  message = "ERROR: flag %s requires %d arguments" % (flag,narg)
  print(message)
  sys.exit(1)
#end def argnerr(narg,flag)

#---------------------------------------------------------
def parse_args(argv):
  """Parse the command line into the module-level option globals.

  Recognized flags:
    --csv FILE                 sets csvfile
    --f / --field NAME...      appends to fields
    --rid / --RID RID...       appends to ridlist
    --s SUBJECT...             appends int(SUBJECT) as a string to ridlist
                               (this drops leading zeros, eg 0041 -> 41)
    --v / --viscode CODE...    appends to vlist
    --l / --label LABEL...     appends to labellist
    --tsv                      sets delimiter to a tab
    --debug                    sets debug to 1

  Args:
    argv: full argument list with the program name first (normally
      sys.argv). The list is not modified.

  Returns:
    0 on success.

  Raises:
    SystemExit: with status 1 on an unrecognized flag, a flag that is
      not followed by at least one value, or a --s value that is not
      an integer.
  """
  global csvfile,fields,ridlist,vlist,labellist,delimiter
  global debug

  # Work on a copy without the program name so the caller's list
  # (usually sys.argv) is left intact, and walk it with an index
  # instead of repeatedly deleting the first element.
  args = list(argv[1:])
  nargs = len(args)

  def collect_values(flag, start):
    # Gather the run of non-flag arguments beginning at start.
    # Returns (values, index of the first argument not consumed).
    stop = start
    while(stop < nargs and not isflag(args[stop])):
      stop += 1
    # A flag followed directly by another flag (or by nothing) has no
    # values; report it rather than silently collecting an empty list.
    if(stop == start): argnerr(1,flag)
    return args[start:stop], stop

  pos = 0
  while(pos < nargs):
    flag = args[pos]
    pos += 1
    if(debug): print("flag = %s" % flag)

    if(flag == "--csv"):
      if(pos >= nargs): argnerr(1,flag)
      csvfile = args[pos]
      pos += 1
    elif(flag == "--field" or flag == "--f"):
      values, pos = collect_values(flag, pos)
      fields.extend(values)
    elif(flag == "--rid" or flag == "--RID"):
      values, pos = collect_values(flag, pos)
      ridlist.extend(values)
    elif(flag == "--s"):
      values, pos = collect_values(flag, pos)
      for subject in values:
        # int() strips the leading zeros of the FreeSurfer subject name
        try:
          rid = int(subject)
        except ValueError:
          print("ERROR: subject %s cannot be converted to an RID" % subject)
          sys.exit(1)
        ridlist.append("%s" % rid)
    elif(flag == "--v" or flag == "--viscode"):
      values, pos = collect_values(flag, pos)
      vlist.extend(values)
    elif(flag == "--l" or flag == "--label"):
      values, pos = collect_values(flag, pos)
      labellist.extend(values)
    elif(flag == "--tsv"):
      delimiter = "\t"
    elif(flag == "--debug"):
      debug = 1
    else:
      print("ERROR: flag %s not recognized" % flag)
      sys.exit(1)
    #endif
  #endwhile
  return 0
#end def parse_args(argv)

#---------------------------------------------------------
def isflag(arg):
  """Tell whether a command-line token looks like a long option.

  A token counts as a flag when it is at least three characters long
  and its first two characters are dashes, so a bare "--" and
  single-dash tokens such as "-5" are treated as ordinary values.

  Args:
    arg: the command-line token to test.

  Returns:
    1 if arg is a flag, 0 otherwise.
  """
  long_enough = len(arg) >= 3
  if(long_enough and arg[0] == "-" and arg[1] == "-"):
    return 1
  return 0
# end def isflag(arg)

#---------------------------------------------------------
def check_args():
  """Verify that the mandatory options were supplied.

  Reads the module-level csvfile and fields. The csv file is checked
  first, then the field list; the first one found empty is reported
  on stdout and the program exits.

  Returns:
    0 when both are present.

  Raises:
    SystemExit: with status 1 when csvfile or fields is empty.
  """
  global csvfile, fields

  # (error message, value) pairs, in the order they are reported
  required = (
    ("ERROR: csv file needed", csvfile),
    ("ERROR: field needed", fields),
  )
  for message, value in required:
    if(not value):
      print(message)
      sys.exit(1)
    #endif
  #end
  return 0
#end check_args()

#-----------------------------------------------------------
# ------ main -----------------------------------------------
#-----------------------------------------------------------

# Option globals; parse_args() fills these in from the command line.
debug = 0
csvfile = ""      # path of the csv/tsv file to read
fields  = []      # names of the columns to print
ridlist = []      # keep only rows whose RID is listed (empty = keep all)
vlist = []        # keep only rows whose VISCODE is listed (empty = keep all)
labellist = []    # keep only rows whose Label is listed (empty = keep all)
delimiter = ","

nargs = len(sys.argv) - 1
if(nargs == 0):
  print_help()
  sys.exit(0)
#end
parse_args(sys.argv)
check_args()

# A recognized file extension decides the delimiter; note that a .csv
# extension therefore takes precedence over an explicit --tsv.
filename, file_extension = os.path.splitext(csvfile)
if(file_extension == ".csv"): delimiter = ","
if(file_extension == ".tsv"): delimiter = "\t"

# Row filters as (column name, accepted values), in the order in which
# a missing column is reported.
row_filters = (("RID", ridlist), ("VISCODE", vlist), ("Label", labellist))

# newline='' lets the csv module do its own newline handling (needed for
# quoted fields that contain line breaks); the with-block guarantees the
# file is closed, including when sys.exit() is called inside it.
with open(csvfile, 'r', newline='') as csvfh:
  reader = csv.reader(csvfh, delimiter=delimiter, quotechar='"')

  # First row holds the field names. An empty file has nothing to print.
  header = next(reader, None)
  if(header is None):
    sys.exit(0)
  #endif

  # Resolve each requested filter to (column index, set of accepted values)
  active_filters = []
  for column, wanted in row_filters:
    if(len(wanted) == 0): continue
    if(column not in header):
      print("ERROR: cannot find field %s in %s" % (column,header))
      sys.exit(1)
    #endif
    active_filters.append((header.index(column), set(wanted)))
  #end

  # Get indices of the fields to print
  ind = []
  for field in fields:
    if(field not in header):
      print('')
      print("ERROR: cannot find field %s in %s" % (field,header))
      print('')
      sys.exit(1)
    #endif
    ind.append(header.index(field))
  #end

  for row in reader:
    ncols = len(row)

    # Skip rows rejected by any filter. A row too short to contain the
    # filter column (eg, a blank line) cannot match, so it is skipped too.
    rejected = False
    for col, wanted in active_filters:
      if(col >= ncols or row[col] not in wanted):
        rejected = True
        break
    # endfor
    if(rejected): continue

    # Skip missing data: the row must actually have every requested
    # column (index i needs at least i+1 columns) and none may be empty.
    missing = False
    for i in ind:
      if(i >= ncols or len(row[i]) == 0):
        missing = True
        break
    # endfor
    if(missing): continue

    # Print out each field in field list, each followed by a space
    for i in ind:
      sys.stdout.write("%s " % (row[i]))
    # endfor
    print('')
  #end for row in reader:
#end with

sys.exit(0)
#-------------------------------------------------#



