package org.apache.spark.api.ruby

import java.io.{File, DataInputStream, InputStream, DataOutputStream, FileOutputStream}
import java.net.{InetAddress, ServerSocket, Socket, SocketException}
import java.nio.file.Paths

import scala.collection.mutable
import scala.collection.JavaConversions._

import org.apache.spark._
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.util.Utils
import org.apache.spark.util.RedirectThread

/* =================================================================================================
 * Object RubyWorker
 * =================================================================================================
 *
 * Creates and stores the server used for creating workers.
 */

object RubyWorker extends Logging {

val PROCESS_WAIT_TIMEOUT = 10000

private var serverSocket: ServerSocket = null
private val serverHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
private var serverPort: Int = 0

private var master: ExecutedFileCommand = null
private var masterSocket: Socket = null
private var masterOutputStream: DataOutputStream = null
private var masterInputStream: DataInputStream = null

private var workers = new mutable.WeakHashMap[Socket, Long]()

/* ----------------------------------------------------------------------------------------------
 * Create a new worker, but first check whether the SocketServer and the master process exist.
 * If they do not, create them. Worker creation is attempted at most twice.
 */

def create(env: SparkEnv): (Socket, Long) = {
  synchronized {
    // Create the server if it hasn't been started
    createServer(env)

    // Attempt to connect, restart and retry once if it fails
    try {
      createWorker
    } catch {
      case exc: SocketException =>
        logWarning("Worker unexpectedly quit, attempting to restart")
        createWorker
    }
  }
}
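
/* Hypothetical caller sketch (not part of this file): a task would typically obtain a worker,
 * exchange data over the returned socket, and then unregister it, e.g.
 *
 *   val (socket, workerId) = RubyWorker.create(SparkEnv.get)
 *   // ... write the serialized partition to the socket, read results back ...
 *   RubyWorker.remove(socket, workerId)
 */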

/* ----------------------------------------------------------------------------------------------
 * Create a worker through the master process. Return the new socket and the worker id.
 * Depending on spark.ruby.worker.type, the id will be:
 *   process: PID
 *   thread: thread object id
 */

def createWorker: (Socket, Long) = {
  synchronized {
    masterOutputStream.writeInt(RubyConstant.CREATE_WORKER)
    val socket = serverSocket.accept()

    val id = new DataInputStream(socket.getInputStream).readLong()
    workers.put(socket, id)

    (socket, id)
  }
}
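
/* Protocol sketch (inferred from createWorker above, not an authoritative spec): master.rb is
 * expected to read the CREATE_WORKER command, spawn a worker as a process or thread, have it
 * connect back to this server, and write its id as a big-endian long, matching readLong() above.
 */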

/* ----------------------------------------------------------------------------------------------
 * Create the SocketServer and bind it to localhost. The maximum number of queued connections
 * is left at the default. If the server is created without an exception -> create the master.
 */

private def createServer(env: SparkEnv){
  synchronized {
    // Already running?
    if(serverSocket != null && masterSocket != null) {
      return
    }

    try {
      // Start the SocketServer for communication
      serverSocket = new ServerSocket(0, 0, serverHost)
      serverPort = serverSocket.getLocalPort

      // Create a master for worker creations
      createMaster(env)
    } catch {
      case e: Exception =>
        throw new SparkException("There was a problem with creating a server", e)
    }
  }
}

/* ----------------------------------------------------------------------------------------------
 * At this point the SocketServer must already exist. The master process creates and kills
 * workers. Creating workers directly from Java can be expensive because each new process may
 * receive a copy of the address space.
 */

private def createMaster(env: SparkEnv){
  synchronized {
    val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER
    val executorOptions = env.conf.get("spark.ruby.executor.options", "")
    val commandTemplate = env.conf.get("spark.ruby.executor.command")
    val workerType = env.conf.get("spark.ruby.worker.type")

    // Root directory of the ruby-spark gem
    var executorLocation = ""

    if(isDriver){
      // Use worker from current active gem location
      executorLocation = env.conf.get("spark.ruby.driver_home")
    }
    else{
      // Use gem installed on the system
      try {
        val homeCommand = (new FileCommand(commandTemplate, "ruby-spark home", env, getEnvVars(env))).run
        executorLocation = homeCommand.readLine
      } catch {
        case e: Exception =>
          throw new SparkException("Ruby-spark gem is not installed.", e)
      }
    }

    // Master and worker are saved in GEM_ROOT/lib/spark/worker
    executorLocation = Paths.get(executorLocation, "lib", "spark", "worker").toString

    // Create master command
    // -C: change to the worker directory before execution
    val masterRb = s"ruby $executorOptions -C $executorLocation master.rb $workerType $serverPort"
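    // Hypothetical example of the resulting command (paths and port are illustrative only):
    //   ruby -C /var/lib/gems/ruby-spark/lib/spark/worker master.rb process 34567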
    val masterCommand = new FileCommand(commandTemplate, masterRb, env, getEnvVars(env))

    // Start master
    master = masterCommand.run

    // Redirect master stdout and stderr
    redirectStreamsToStderr(master.getInputStream, master.getErrorStream)

    // Wait for it to connect to our socket
    serverSocket.setSoTimeout(PROCESS_WAIT_TIMEOUT)
    try {
      // Use the socket for communication. Keep stdout and stderr for logging only.
      masterSocket = serverSocket.accept()
      masterOutputStream = new DataOutputStream(masterSocket.getOutputStream)
      masterInputStream  = new DataInputStream(masterSocket.getInputStream)

      PythonRDD.writeUTF(executorOptions, masterOutputStream)
    } catch {
      case e: Exception =>
        throw new SparkException("Ruby master did not connect back in time", e)
    }
  }
}

/* ----------------------------------------------------------------------------------------------
 * Get all environment variables for the executor.
 */

def getEnvVars(env: SparkEnv): Map[String, String] = {
  val prefix = "spark.ruby.executor.env."
  env.conf.getAll.filter{case (k, _) => k.startsWith(prefix)}
                 .map{case (k, v) => (k.substring(prefix.length), v)}
                 .toMap
}
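
// Example (hypothetical values): setting "spark.ruby.executor.env.GEM_HOME" -> "/opt/gems"
// in SparkConf yields Map("GEM_HOME" -> "/opt/gems") here; the prefix is stripped and the
// remainder is presumably passed as environment variables to the launched command.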

/* ------------------------------------------------------------------------------------------- */

def kill(workerId: Long){
  masterOutputStream.writeInt(RubyConstant.KILL_WORKER)
  masterOutputStream.writeLong(workerId)
}

/* ------------------------------------------------------------------------------------------- */

def killAndWait(workerId: Long){
  masterOutputStream.writeInt(RubyConstant.KILL_WORKER_AND_WAIT)
  masterOutputStream.writeLong(workerId)

  // Wait for answer
  masterInputStream.readInt() match {
    case RubyConstant.SUCCESSFULLY_KILLED =>
      logInfo(s"Worker $workerId was successfully killed")
    case RubyConstant.UNSUCCESSFUL_KILLING =>
      logInfo(s"Worker $workerId cannot be killed (it may already be dead)")
  }
}

/* ----------------------------------------------------------------------------------------------
 * The workers map is weak, but explicit removal avoids keeping a long list of workers
 * which can no longer be killed (via killAndWait).
 */

def remove(worker: Socket, workerId: Long){
  try {
    workers.remove(worker)
  } catch {
    case e: Exception => logWarning(s"Worker $workerId does not exist (it may already have been removed)")
  }
}

/* ------------------------------------------------------------------------------------------- */

def stopServer{
  synchronized {
    // Kill workers
    workers.foreach { case (socket, id) => killAndWait(id) }

    // Kill master
    master.destroy

    // Stop SocketServer
    serverSocket.close()

    // Clean variables
    serverSocket = null
    serverPort = 0
    master = null
    masterSocket = null
    masterOutputStream = null
    masterInputStream = null
  }
}

/* ------------------------------------------------------------------------------------------- */

private def redirectStreamsToStderr(streams: InputStream*) {
  try {
    for(stream <- streams) {
      new RedirectThread(stream, System.err, "stream reader").start()
    }
  } catch {
    case e: Exception =>
      logError("Exception in redirecting streams", e)
  }
}

/* ------------------------------------------------------------------------------------------- */

}