#!/usr/bin/env bash

# Probe for every hwloc utility this script uses. HAS_HWLOC accumulates the
# exit status of each lookup, so it is 0 only when all of them were found.
declare -i HAS_HWLOC=0
for hwloc_tool in hwloc-bind hwloc-distrib hwloc-ls hwloc-calc hwloc-ps; do
  type "${hwloc_tool}" >/dev/null 2>&1
  # HAS_HWLOC has the integer attribute, so += performs arithmetic addition.
  HAS_HWLOC+=$?
done
unset hwloc_tool


# Option defaults; the argument parsing further down overrides them.
declare -a UNKNOWN_ARGS=()
declare -i DISTRIBUTE=1
declare -i INDEX=0
PROC_BIND="all"
CURRENT_CPUSET=""
OPENMP_VERSION=4.0
OPENMP_PROC_BIND=True
OPENMP_NESTED=True
VERBOSE=False

#######################################
# Print the cpuset that hwloc-ps reports for one process.
# Arguments: $1 - process id to look up
# Outputs:   the second tab-separated field of the hwloc-ps line whose first
#            field equals the pid; nothing when no line matches
#######################################
function get_process_cpuset {
  local pid="$1"
  # Compare the pid field exactly. A plain substring match on the pid would
  # also hit other pids, cpuset masks or process names that merely contain
  # the same digits, and could yield several lines.
  hwloc-ps --cpuset | awk -F '\t' -v pid="${pid}" '$1 == pid { print $2 }'
}

# Get the current process cpuset. It stays empty when hwloc-ps does not list
# this process. The value is not echoed here so the wrapped command's stdout
# stays clean; the verbose report prints it as parent_cpuset.
if [[ ${HAS_HWLOC} -eq 0 ]]; then
  MY_PID="$BASHPID"
  CURRENT_CPUSET=$(get_process_cpuset "${MY_PID}")
fi

#######################################
# Print the usage text for this script.
# Arguments: none
# Outputs:   usage line, option summary and a sample invocation on stdout
#######################################
function show_help {
  # Name the script was invoked as, without any leading directory.
  local cmd="${0##*/}"
  printf '%s\n' \
    "Usage: ${cmd} <options> -- command ..." \
    "  Uses hwloc to divide the node into the given number of groups," \
    "  set the appropriate OMP_NUM_THREADS and execute the command on the" \
    "  selected group." \
    "" \
    "  NOTE: This command assumes it has exclusive use of the node" \
    "" \
    "Options:" \
    "  --proc-bind=<LOC>     Set the initial process mask for the script.  " \
    "                        LOC can be any valid location argument for" \
    "                        hwloc-calc.  Defaults to the entire machine" \
    "  --distribute=N        Distribute the current proc-bind into N groups" \
    "  --index=I             Use the i'th group (zero based)" \
    "  --openmp=M.m          Set env variables for the given OpenMP version" \
    "                        (default 4.0)" \
    "  --no-openmp-proc-bind Set OMP_PROC_BIND to false and unset OMP_PLACES" \
    "  --no-openmp-nested    Set OMP_NESTED to false" \
    "  -v|--verbose" \
    "  -h|--help" \
    "" \
    "Sample Usage:" \
    "  ${cmd} --distribute=4 --index=2 -v -- command ..." \
    ""
}

# Invoked without any arguments: print the usage text and exit successfully.
(( $# )) || {
  show_help
  exit 0
}


#######################################
# Interpret the leading script options, stopping after the first "--".
# Globals:   PROC_BIND, DISTRIBUTE, INDEX, OPENMP_VERSION, OPENMP_PROC_BIND,
#            OPENMP_NESTED, VERBOSE (written when the matching option is seen)
#            UNKNOWN_ARGS (appended with every unrecognised argument)
#            PARSED_ARG_COUNT (written: number of arguments consumed,
#            including the "--" separator when present)
# Arguments: the script's positional parameters
# Returns:   0; prints the help text and exits the script on -h|--help
#######################################
function parse_args {
  local arg
  PARSED_ARG_COUNT=0
  # Iterate over "$@" quoted so that every argument is seen exactly once and
  # unmodified; the consumed count then always matches the real parameters,
  # even when an option value contains whitespace or glob characters.
  for arg in "$@"; do
    PARSED_ARG_COUNT=$((PARSED_ARG_COUNT + 1))
    case "${arg}" in
      # initial process mask
      --proc-bind=*)
        PROC_BIND="${arg#*=}"
        ;;
      # number of partitions to create
      --distribute=*)
        DISTRIBUTE="${arg#*=}"
        ;;
      # which group to use
      --index=*)
        INDEX="${arg#*=}"
        ;;
      --openmp=*)
        OPENMP_VERSION="${arg#*=}"
        ;;
      --no-openmp-proc-bind)
        OPENMP_PROC_BIND=False
        ;;
      --no-openmp-nested)
        OPENMP_NESTED=False
        ;;
      -v|--verbose)
        VERBOSE=True
        ;;
      -h|--help)
        show_help
        exit 0
        ;;
      # everything after the separator belongs to the wrapped command
      --)
        return 0
        ;;
      # unknown option
      *)
        UNKNOWN_ARGS+=("${arg}")
        ;;
    esac
  done
}

parse_args "$@"
# Drop what was consumed so that "$@" is exactly the command to run.
shift "${PARSED_ARG_COUNT}"

#######################################
# Sanity-check the parsed options and fall back to safe values.
# Globals:   UNKNOWN_ARGS, HAS_HWLOC (read); DISTRIBUTE, INDEX (read/written)
# Outputs:   diagnostics on stderr
# Returns:   0; exits the script with status 1 when unknown options were given
#######################################
function validate_args {
  # Numeric test: inside [[ ]] the ">" operator would compare strings.
  if [[ ${#UNKNOWN_ARGS[@]} -gt 0 ]]; then
    echo "Unknown options: ${UNKNOWN_ARGS[*]}" >&2
    exit 1
  fi

  if [[ ${DISTRIBUTE} -le 0 ]]; then
    echo "Invalid input for distribute, changing distribute to 1" >&2
    DISTRIBUTE=1
  fi

  # INDEX is later used as an array subscript, so it must lie in
  # [0, DISTRIBUTE); negative values are rejected as well.
  if [[ ${INDEX} -lt 0 || ${INDEX} -ge ${DISTRIBUTE} ]]; then
    echo "Invalid input for index, changing index to 0" >&2
    INDEX=0
  fi

  # Without the hwloc tools nothing can be partitioned.
  if [[ ${HAS_HWLOC} -ne 0 ]]; then
    echo "hwloc not found, no process binding will occur" >&2
    DISTRIBUTE=1
    INDEX=0
  fi
}

validate_args

#######################################
# Export the OpenMP environment for the wrapped command.
# Globals:   OPENMP_PROC_BIND, OPENMP_VERSION, OPENMP_NESTED (read)
#            OMP_PLACES, OMP_PROC_BIND, OMP_NESTED, OMP_NUM_THREADS
#            (exported or unset)
# Arguments: $1 - value for OMP_NUM_THREADS
#######################################
function set_openmp_env {
  local num_threads="$1"
  # Only the major version decides which variables are used. It is compared
  # numerically: a string comparison would rank "4" and "10.0" below "4.0".
  # A version without a numeric major part is handled like a pre-4.0 one.
  local major="${OPENMP_VERSION%%.*}"

  if [[ "${OPENMP_PROC_BIND}" == "True" ]]; then
    if [[ "${major}" =~ ^[0-9]+$ ]] && (( 10#${major} >= 4 )); then
      export OMP_PLACES="threads"
      export OMP_PROC_BIND="spread"
    else
      export OMP_PROC_BIND="true"
      unset OMP_PLACES
    fi
  else
    unset OMP_PLACES
    unset OMP_PROC_BIND
  fi

  if [[ "${OPENMP_NESTED}" == "True" ]]; then
    export OMP_NESTED="true"
  else
    export OMP_NESTED="false"
  fi

  export OMP_NUM_THREADS="${num_threads}"
}

# An unset HAS_HWLOC is treated as "hwloc unavailable".
if [[ ${HAS_HWLOC:-1} -eq 0 ]]; then

  # PROC_BIND stays unquoted on purpose: it may hold several whitespace
  # separated locations, each of which must reach hwloc-calc as its own word.
  if [[ -z "${CURRENT_CPUSET}" ]]; then
    BINDING=$(hwloc-calc ${PROC_BIND})
  else
    BINDING=$(hwloc-calc --restrict "${CURRENT_CPUSET}" ${PROC_BIND})
  fi

  # Split the hwloc-distrib output on whitespace into one array entry per group.
  CPUSETS=($(hwloc-distrib --restrict "${BINDING}" --at core "${DISTRIBUTE}"))
  CPUSET=${CPUSETS[${INDEX}]}
  # One output line per processing unit inside the selected cpuset.
  NUM_THREADS=$(hwloc-ls --restrict "${CPUSET}" --only pu | wc -l)

  if [[ "${VERBOSE}" == "True" ]]; then
    echo "hwloc:         true"
    echo "  proc_bind:     ${PROC_BIND}"
    echo "  distribute:    ${DISTRIBUTE}"
    echo "  index:         ${INDEX}"
    echo "  parent_cpuset: ${CURRENT_CPUSET}"
    echo "  cpuset:        ${CPUSET}"
    echo "omp_num_threads: ${NUM_THREADS}"
    echo "omp_proc_bind:   ${OPENMP_PROC_BIND}"
    echo "omp_nested:      ${OPENMP_NESTED}"
    echo "OpenMP:          ${OPENMP_VERSION}"
  fi

  set_openmp_env "${NUM_THREADS}"

  # "$@" keeps every argument of the wrapped command intact. This must be the
  # last command so that its exit status becomes the script's.
  hwloc-bind "${CPUSET}" -- "$@"
else
  # Count the "processor" entries; fall back to 0 when the file is unreadable.
  NUM_THREADS=$(grep -c '^processor' /proc/cpuinfo 2>/dev/null) || NUM_THREADS=0

  if [[ "${VERBOSE}" == "True" ]]; then
    echo "hwloc:           false"
    echo "omp_num_threads: ${NUM_THREADS}"
    echo "omp_proc_bind:   ${OPENMP_PROC_BIND}"
    echo "omp_nested:      ${OPENMP_NESTED}"
    echo "OpenMP:          ${OPENMP_VERSION}"
  fi

  set_openmp_env "${NUM_THREADS}"

  # Run the command directly instead of through eval, which would split and
  # evaluate its arguments a second time.
  "$@"
fi

