package org.apache.spark.api.ruby
import java.io._
import java.net._
import java.util.{List, ArrayList, Collections}

import scala.util.Try
import scala.reflect.ClassTag
import scala.collection.JavaConversions._

import org.apache.spark._
import org.apache.spark.{SparkEnv, Partition, SparkException, TaskContext}
import org.apache.spark.api.ruby._
import org.apache.spark.api.ruby.marshal._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.api.python.PythonRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.InterruptibleIterator
/* =================================================================================================
 * Class RubyRDD
 * =================================================================================================
 */
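/*
 * Wire protocol between the JVM and the Ruby worker, as implemented by compute, WriterThread
 * and StreamReader below (a summary of the code, kept here for orientation):
 *
 *   JVM -> worker: partition index, SparkFiles root directory, broadcast (id, path) pairs,
 *                  the serialized command, the parent partition's records, then DATA_EOF.
 *   worker -> JVM: length-prefixed result records, terminated either by WORKER_DONE followed
 *                  by accumulator updates, or by WORKER_ERROR followed by an error message
 *                  and a stack trace.
 */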
class RubyRDD(
    @transient parent: RDD[_],
    command: Array[Byte],
    broadcastVars: ArrayList[Broadcast[RubyBroadcast]],
    accumulator: Accumulator[List[Array[Byte]]])
  extends RDD[Array[Byte]](parent) {

  val bufferSize = conf.getInt("spark.buffer.size", 65536)

  val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

  override def getPartitions: Array[Partition] = firstParent.partitions

  override val partitioner = None

  /* ------------------------------------------------------------------------------------------ */

  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    val env = SparkEnv.get

    // Get worker and id
    val (worker, workerId) = RubyWorker.create(env)

    // Start a thread to feed the process input from our parent's iterator
    val writerThread = new WriterThread(env, worker, split, context)

    context.addTaskCompletionListener { context =>
      writerThread.shutdownOnTaskCompletion()
      writerThread.join()

      // Clean up the worker socket. This will also cause the worker to exit.
      try {
        RubyWorker.remove(worker, workerId)
        worker.close()
      } catch {
        case e: Exception => logWarning("Failed to close worker socket", e)
      }
    }

    val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))

    // Send data
    writerThread.start()

    // Monitor thread that forcibly terminates the worker if the task is interrupted
    new MonitorThread(workerId, worker, context).start()

    // Return an iterator that reads the worker's output records
    val stdoutIterator = new StreamReader(stream, writerThread, context)

    // An iterator that wraps around an existing iterator to provide task killing functionality.
    new InterruptibleIterator(context, stdoutIterator)
  } // end compute

  /* ------------------------------------------------------------------------------------------ */

  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
    extends Thread("stdout writer for worker") {

    @volatile private var _exception: Exception = null

    setDaemon(true)

    // Contains the exception thrown while writing the parent iterator to the process.
    def exception: Option[Exception] = Option(_exception)

    // Terminates the writer thread, ignoring any exceptions that may occur due to cleanup.
    def shutdownOnTaskCompletion() {
      assert(context.isCompleted)
      this.interrupt()
    }

    // -------------------------------------------------------------------------------------------
    // Send the necessary data to the worker:
    //   - split index
    //   - command
    //   - iterator
    override def run(): Unit = Utils.logUncaughtExceptions {
      try {
        SparkEnv.set(env)
        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
        val dataOut = new DataOutputStream(stream)

        // Partition index
        dataOut.writeInt(split.index)

        // Spark files
        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)

        // Broadcast variables
        dataOut.writeInt(broadcastVars.length)
        for (broadcast <- broadcastVars) {
          dataOut.writeLong(broadcast.value.id)
          PythonRDD.writeUTF(broadcast.value.path, dataOut)
        }

        // Serialized command
        dataOut.writeInt(command.length)
        dataOut.write(command)

        // Send it
        dataOut.flush()

        // Data
        PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
        dataOut.writeInt(RubyConstant.DATA_EOF)
        dataOut.flush()
      } catch {
        case e: Exception if context.isCompleted || context.isInterrupted =>
          logDebug("Exception thrown after task completion (likely due to cleanup)", e)

        case e: Exception =>
          // We must avoid throwing exceptions here, because the thread uncaught exception handler
          // will kill the whole executor (see org.apache.spark.executor.Executor).
          _exception = e
      } finally {
        Try(worker.shutdownOutput()) // kill worker process
      }
    }
  } // end WriterThread

  /* ------------------------------------------------------------------------------------------ */

  class StreamReader(stream: DataInputStream, writerThread: WriterThread, context: TaskContext)
    extends Iterator[Array[Byte]] {

    def hasNext = _nextObj != null
    var _nextObj = read()

    // -------------------------------------------------------------------------------------------

    def next(): Array[Byte] = {
      val obj = _nextObj
      if (hasNext) {
        _nextObj = read()
      }
      obj
    }

    // -------------------------------------------------------------------------------------------

    private def read(): Array[Byte] = {
      if (writerThread.exception.isDefined) {
        throw writerThread.exception.get
      }
      try {
        stream.readInt() match {
          case length if length > 0 =>
            val obj = new Array[Byte](length)
            stream.readFully(obj)
            obj

          case RubyConstant.WORKER_DONE =>
            val numAccumulatorUpdates = stream.readInt()
            (1 to numAccumulatorUpdates).foreach { _ =>
              val updateLen = stream.readInt()
              val update = new Array[Byte](updateLen)
              stream.readFully(update)
              accumulator += Collections.singletonList(update)
            }
            null

          case RubyConstant.WORKER_ERROR =>
            // Exception from the worker

            // Message
            val length = stream.readInt()
            val obj = new Array[Byte](length)
            stream.readFully(obj)

            // Stack trace
            val stackTraceLen = stream.readInt()
            val stackTrace = new Array[String](stackTraceLen)
            (0 until stackTraceLen).foreach { i =>
              val frameLen = stream.readInt()
              val frame = new Array[Byte](frameLen)
              stream.readFully(frame)
              stackTrace(i) = new String(frame, "utf-8")
            }

            // Worker will be killed
            stream.close()

            // Exception
            val exception = new RubyException(new String(obj, "utf-8"), writerThread.exception.orNull)
            exception.appendToStackTrace(stackTrace)
            throw exception
        }
      } catch {
        case e: Exception if context.isInterrupted =>
          logDebug("Exception thrown after task interruption", e)
          throw new TaskKilledException

        case e: Exception if writerThread.exception.isDefined =>
          logError("Worker exited unexpectedly (crashed)", e)
          throw writerThread.exception.get

        case eof: EOFException =>
          throw new SparkException("Worker exited unexpectedly (crashed)", eof)
      }
    }
  } // end StreamReader

  /* ---------------------------------------------------------------------------------------------
   * Monitor thread that controls the worker: kills the worker if the task is interrupted. */

  class MonitorThread(workerId: Long, worker: Socket, context: TaskContext)
    extends Thread("Worker Monitor for worker") {

    setDaemon(true)

    override def run() {
      // Kill the worker if it is interrupted, checking until task completion.
      while (!context.isInterrupted && !context.isCompleted) {
        Thread.sleep(2000)
      }
      if (!context.isCompleted) {
        try {
          logWarning("Incomplete task interrupted: Attempting to kill worker " + workerId.toString)
          RubyWorker.kill(workerId)
        } catch {
          case e: Exception =>
            logError("Exception when trying to kill worker " + workerId.toString, e)
        }
      }
    }
  } // end MonitorThread

} // end RubyRDD
/* =================================================================================================
 * Class PairwiseRDD
 * =================================================================================================
 *
 * Form an RDD[(Long, Array[Byte])] from key-value pairs returned from Ruby.
 * This is used by Ruby-Spark's shuffle operations.
 * Borrowed from the Python package; a new deserializeLongValue will be needed, since
 * Marshal will add the same 4-byte header.
 */
class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte])](prev) {

  override def getPartitions = prev.partitions

  override def compute(split: Partition, context: TaskContext) =
    prev.iterator(split, context).grouped(2).map {
      case Seq(a, b) => (Utils.deserializeLongValue(a.reverse), b)
      case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
    }

  val asJavaPairRDD: JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
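
/* A minimal, self-contained sketch (not used by Spark itself) of the pairing step above: the Ruby
 * worker emits keys and values as an alternating flat stream of byte records, and grouped(2)
 * reassembles them into (key, value) pairs. The byte contents below are placeholders, not the
 * real Marshal encoding. */
object PairwiseGroupingSketch {
  def main(args: Array[String]): Unit = {
    // A flat stream as it might arrive from the worker: key, value, key, value, ...
    val flat = Iterator[Array[Byte]](Array[Byte](1), Array[Byte](10), Array[Byte](2), Array[Byte](20))
    val pairs = flat.grouped(2).map { case Seq(k, v) => (k.toSeq, v.toSeq) }
    pairs.foreach(println) // prints one (key, value) tuple per line
  }
}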
/* =================================================================================================
 * Object RubyRDD
 * =================================================================================================
 */
object RubyRDD extends Logging {
  def runJob(
      sc: SparkContext,
      rdd: JavaRDD[Array[Byte]],
      partitions: ArrayList[Int],
      allowLocal: Boolean,
      filename: String): String = {
    type ByteArray = Array[Byte]
    type UnrolledPartition = Array[ByteArray]
    val allPartitions: Array[UnrolledPartition] =
      sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
    val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
    writeRDDToFile(flattenedPartition.iterator, filename)
  }

  def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = {
    val file = new DataInputStream(new BufferedInputStream(new FileInputStream(filename)))
    val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
    try {
      while (true) {
        val length = file.readInt()
        val obj = new Array[Byte](length)
        file.readFully(obj)
        objs.append(obj)
      }
    } catch {
      case _: EOFException => // all records read
    } finally {
      file.close()
    }
    JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
  }

  def writeRDDToFile[T](items: Iterator[T], filename: String): String = {
    val file = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(filename)))
    try {
      PythonRDD.writeIteratorToStream(items, file)
    } finally {
      file.close()
    }
    filename
  }

  def writeRDDToFile[T](rdd: RDD[T], filename: String): String = {
    writeRDDToFile(rdd.collect.iterator, filename)
  }

  def readBroadcastFromFile(sc: JavaSparkContext, path: String, id: java.lang.Long): Broadcast[RubyBroadcast] = {
    sc.broadcast(new RubyBroadcast(path, id))
  }

  /**
   * Convert an RDD of serialized Ruby objects to an RDD of objects that is usable from Java.
   */
  def toJava(rbRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
    rbRDD.rdd.mapPartitions { iter =>
      iter.flatMap { item =>
        val obj = Marshal.load(item)
        if (batched) {
          obj.asInstanceOf[Array[_]]
        } else {
          Seq(obj)
        }
      }
    }.toJavaRDD()
  }

  /**
   * Convert an RDD of Java objects to an RDD of serialized Ruby objects that is usable from Ruby.
   */
  def toRuby(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
    jRDD.rdd.mapPartitions { iter => new IterableMarshaller(iter) }
  }
}
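
/* A minimal, hypothetical driver-side sketch (not part of the public API above): it reads a file
 * of length-prefixed records previously produced by writeRDDToFile or by the Ruby side, and turns
 * it into plain Java objects with toJava. The local master, file path and parallelism are
 * illustrative assumptions only. */
object RubyRDDUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc  = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("ruby-rdd-sketch"))
    val jsc = new JavaSparkContext(sc)

    // Length-prefixed Marshal records written earlier (hypothetical path).
    val bytes: JavaRDD[Array[Byte]] = RubyRDD.readRDDFromFile(jsc, "/tmp/ruby-rdd.bin", 2)

    // Deserialize into Java-usable objects; batched = true assumes each record holds an
    // Array of objects, as produced by a batching serializer.
    val objects: JavaRDD[Any] = RubyRDD.toJava(bytes, batched = true)
    println(objects.count())

    sc.stop()
  }
}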
/* =================================================================================================
 * Class RubyException
 * =================================================================================================
 */
class RubyException(msg: String, cause: Exception) extends RuntimeException(msg, cause) {
  def appendToStackTrace(toAdded: Array[String]) {
    val newStackTrace = getStackTrace.toBuffer
    val frameRegex = "(.*):([0-9]+):in `([a-z]+)'".r

    for (item <- toAdded) {
      item match {
        case frameRegex(fileName, lineNumber, methodName) =>
          newStackTrace += new StackTraceElement("RubyWorker", methodName, fileName, lineNumber.toInt)
        case _ => // skip frames that do not look like a Ruby backtrace line
      }
    }

    setStackTrace(newStackTrace.toArray)
  }
}
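
/* A small, self-contained sketch of the backtrace translation in appendToStackTrace: a Ruby frame
 * such as "worker.rb:42:in `compute'" (format inferred from the regex above, not from a spec) maps
 * to StackTraceElement("RubyWorker", "compute", "worker.rb", 42). */
object RubyBacktraceSketch {
  def main(args: Array[String]): Unit = {
    val frame = "worker.rb:42:in `compute'"
    val framePattern = "(.*):([0-9]+):in `([a-z]+)'".r

    frame match {
      case framePattern(fileName, lineNumber, methodName) =>
        println(new StackTraceElement("RubyWorker", methodName, fileName, lineNumber.toInt))
      case _ =>
        println("frame did not match")
    }
  }
}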