package org.apache.spark.api.ruby.marshal
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.{ClassTag, classTag}

import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
/* =================================================================================================
* class MarshalLoad * ================================================================================================= */
/**
 * Decoder for the subset of Ruby's `Marshal` binary format needed to receive MLlib objects.
 *
 * Supported type tags: nil (`0`), true (`T`), false (`F`), fixnum (`i`), float (`f`),
 * symbol (`:`), array (`[`) and user-marshalled objects (`U`) of the classes
 * `Spark::Mllib::LabeledPoint`, `Spark::Mllib::DenseVector` and `Spark::Mllib::SparseVector`.
 * Nested values may also be back-references to an earlier symbol (`;`) or object (`@`); these
 * are resolved from the tables kept by this instance, which are never cleared.
 *
 * NOTE(review): this class does not read the two Marshal version bytes that start a Ruby
 * Marshal stream; presumably the caller consumes them before calling [[load]] — confirm.
 *
 * @param is stream positioned at the type tag of the next object to decode
 */
class MarshalLoad(is: DataInputStream) {

  /** Placeholder stored in [[registeredLinks]] while a user object's payload is being read. */
  case class WaitForObject()

  /** Symbols in the order they were first read; `;` back-references index into this. */
  val registeredSymbols = ArrayBuffer[String]()

  /** Floats, arrays and user objects in registration order; `@` back-references index into this. */
  val registeredLinks = ArrayBuffer[Any]()

  /** Reads the next type tag from the stream and decodes the value it introduces. */
  def load: Any = {
    load(is.readUnsignedByte.toChar)
  }

  /**
   * Decodes a value whose type tag has already been read.
   *
   * @param dataType the Marshal type tag
   * @throws IllegalArgumentException if the tag is not supported
   */
  def load(dataType: Char): Any = dataType match {
    case '0' => null
    case 'T' => true
    case 'F' => false
    case 'i' => loadInt
    case 'f' => loadAndRegisterFloat
    case ':' => loadAndRegisterSymbol
    case '[' => loadAndRegisterArray
    case 'U' => loadAndRegisterUserObject
    case _   => throw new IllegalArgumentException(s"Format is not supported: $dataType.")
  }

  // ----------------------------------------------------------------------------------------------
  // Load by type

  /**
   * Reads a Marshal "long" (the variable-length integer used for fixnums, lengths and indices).
   *
   * The first signed byte `c` selects the encoding:
   *  - `0` is the value 0;
   *  - `5..127` encodes `c - 5`, and `-128..-5` encodes `c + 5`;
   *  - `1..4` means `c` little-endian bytes of a positive value follow;
   *  - `-1..-4` means `-c` little-endian bytes of a negative (two's complement) value follow.
   */
  def loadInt: Int = {
    val c = is.readByte.toInt
    if (c == 0) {
      0
    } else if (c > 4) {
      c - 5
    } else if (c < -4) {
      c + 5
    } else if (c > 0) {
      var result = 0
      for (i <- 0 until c) {
        result |= is.readUnsignedByte << (8 * i)
      }
      result
    } else {
      // Start from all ones so the bytes that are not transmitted stay sign-extended.
      var result = -1
      for (i <- 0 until -c) {
        result &= ~(0xff << (8 * i))
        result |= is.readUnsignedByte << (8 * i)
      }
      result
    }
  }

  /** Reads a float and records it in [[registeredLinks]] so `@` links can refer to it. */
  def loadAndRegisterFloat: Double = {
    val result = loadFloat
    registeredLinks += result
    result
  }

  /** Reads a float, which is stored as its textual form (`nan`, `inf` and `-inf` are special). */
  def loadFloat: Double = {
    val string = loadString
    string match {
      case "nan"  => Double.NaN
      case "inf"  => Double.PositiveInfinity
      case "-inf" => Double.NegativeInfinity
      case _      => string.toDouble
    }
  }

  /** Reads a length-prefixed byte string and decodes it as UTF-8, independent of the JVM default. */
  def loadString: String = {
    new String(loadStringBytes, StandardCharsets.UTF_8)
  }

  /**
   * Reads a length-prefixed byte string.
   *
   * @throws IllegalArgumentException if the length is negative or the stream ends early
   */
  def loadStringBytes: Array[Byte] = {
    val size = loadInt
    if (size < 0) {
      throw new IllegalArgumentException(s"Invalid string length: $size.")
    }
    val buffer = new Array[Byte](size)
    var readSize = 0
    while (readSize < size) {
      val read = is.read(buffer, readSize, size - readSize)
      if (read == -1) {
        throw new IllegalArgumentException("Marshal too short.")
      }
      readSize += read
    }
    buffer
  }

  /** Reads a symbol and records it in [[registeredSymbols]] so `;` links can refer to it. */
  def loadAndRegisterSymbol: String = {
    val result = loadString
    registeredSymbols += result
    result
  }

  /**
   * Reads an array. It is registered before its elements, so elements may link back to it.
   *
   * @throws IllegalArgumentException if the encoded size is negative
   */
  def loadAndRegisterArray: Array[Any] = {
    val size = loadInt
    if (size < 0) {
      throw new IllegalArgumentException(s"Invalid array size: $size.")
    }
    val array = new Array[Any](size)
    registeredLinks += array
    for (i <- 0 until size) {
      array(i) = loadNextObject
    }
    array
  }

  /**
   * Reads a user-marshalled object (`U`): a class-name symbol followed by its dumped data.
   *
   * A placeholder is registered before the data is read, so the link indices of objects nested
   * in the data line up with the writer's numbering. The placeholder is replaced by the result.
   *
   * @throws IllegalArgumentException if the class name is not a symbol or is not supported
   */
  def loadAndRegisterUserObject: Any = {
    val klass = loadNextObject match {
      case name: String => name
      case other        => throw new IllegalArgumentException(s"Invalid class name: $other.")
    }

    registeredLinks += WaitForObject()
    val index = registeredLinks.size - 1

    val data = loadNextObject
    val result = klass match {
      case "Spark::Mllib::LabeledPoint"  => createLabeledPoint(data)
      case "Spark::Mllib::DenseVector"   => createDenseVector(data)
      case "Spark::Mllib::SparseVector"  => createSparseVector(data)
      case other => throw new IllegalArgumentException(s"Object $other is not supported.")
    }
    registeredLinks(index) = result
    result
  }

  // ----------------------------------------------------------------------------------------------
  // Other loads

  /** Reads a nested value, which may be either a back-reference or a fully encoded object. */
  def loadNextObject: Any = {
    val dataType = is.readUnsignedByte.toChar
    if (isLinkType(dataType)) {
      readLink(dataType)
    } else {
      load(dataType)
    }
  }

  // ----------------------------------------------------------------------------------------------
  // To java objects

  /**
   * Builds a [[LabeledPoint]] from `[label, features]`.
   *
   * The label goes through [[toDouble]], so an Integer label is accepted as well as a Float.
   */
  def createLabeledPoint(data: Any): LabeledPoint = {
    val array = data.asInstanceOf[Array[_]]
    new LabeledPoint(toDouble(array(0)), array(1).asInstanceOf[Vector])
  }

  /** Builds a [[DenseVector]] from an array of numbers. */
  def createDenseVector(data: Any): DenseVector = {
    new DenseVector(data.asInstanceOf[Array[_]].map(toDouble(_)))
  }

  /** Builds a [[SparseVector]] from `[size, indices, values]`. */
  def createSparseVector(data: Any): SparseVector = {
    val array = data.asInstanceOf[Array[_]]
    val size = array(0).asInstanceOf[Int]
    val indices = array(1).asInstanceOf[Array[_]].map(_.asInstanceOf[Int])
    val values = array(2).asInstanceOf[Array[_]].map(toDouble(_))
    new SparseVector(size, indices, values)
  }

  // ----------------------------------------------------------------------------------------------
  // Helpers

  /**
   * Converts a decoded number to `Double`.
   *
   * NOTE(review): any other value (including `null` from Ruby nil) silently becomes 0.0 —
   * confirm this lenient behaviour is intended.
   */
  def toDouble(data: Any): Double = data match {
    case x: Int    => x.toDouble
    case x: Double => x
    case _         => 0.0
  }

  // ----------------------------------------------------------------------------------------------
  // Cache

  /**
   * Resolves a back-reference whose tag (`@` object link or `;` symbol link) was already read.
   *
   * @throws IllegalArgumentException if the tag is not a link tag, the index is out of range,
   *                                  or the link targets a user object still being decoded
   */
  def readLink(dataType: Char): Any = {
    val index = loadInt
    val table: ArrayBuffer[_] = dataType match {
      case '@'   => registeredLinks
      case ';'   => registeredSymbols
      case other => throw new IllegalArgumentException(s"Not a link type: $other.")
    }
    if (index < 0 || index >= table.size) {
      throw new IllegalArgumentException(s"Link index $index is out of range.")
    }
    table(index) match {
      // The placeholder is never patched in the caller's data, so returning it would corrupt
      // the result; a self-referencing user object cannot be represented here.
      case WaitForObject() =>
        throw new IllegalArgumentException(
          s"Link $index refers to an object that is still being loaded.")
      case value => value
    }
  }

  /** Whether `dataType` is a back-reference tag (`;` for symbols, `@` for objects). */
  def isLinkType(dataType: Char): Boolean = {
    dataType == ';' || dataType == '@'
  }
}