class Chainer::Functions::Array::Cast

Public Class Methods

cast(x, type)

Cast an input variable to a given type.

@param x [Chainer::Variable or Numo::NArray] Input variable to be cast.
@param type [Numo::NArray class] Data class to cast to.
@return [Chainer::Variable] Variable holding the cast array.

Example:

  > x = Numo::UInt8.new(3, 5).seq
  > x.class # => Numo::UInt8
  > y = Chainer::Functions::Array::Cast.cast(x, Numo::DFloat)
  > y.dtype # => Numo::DFloat

# File lib/chainer/functions/array/cast.rb, line 18
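def self.cast(x, type)
  # Input already has the requested type: wrap it as a Variable and skip the function.
  if (Chainer.array?(x) && x.class == type) || (x.is_a?(Chainer::Variable) && x.dtype == type)
    return Chainer::Variable.as_variable(x)
  end
  # Otherwise apply the Cast function and return its single output.
  self.new(type).apply([x]).first
end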
new(type)
# File lib/chainer/functions/array/cast.rb, line 25
def initialize(type)
  @type = type
end

Public Instance Methods

backward(indexes, g)
# File lib/chainer/functions/array/cast.rb, line 34
def backward(indexes, g)
  # Cast the incoming gradient back to the array class recorded in forward.
  [Cast.cast(g.first, @in_type)]
end
forward(x)
# File lib/chainer/functions/array/cast.rb, line 29
def forward(x)
  # Remember the input's array class so backward can restore it for the gradient.
  @in_type = x.first.class
  [x.first.cast_to(@type)]
end
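
The forward and backward methods together make the cast a round trip: forward records the class of its input and converts to the target type, and backward converts the gradient back to the recorded class. The sketch below mimics that pattern in plain Ruby, with arrays of Integer and Float values standing in for Numo arrays. ToyCast and its CONVERTERS table are hypothetical names made up for illustration and are not part of Chainer.

```ruby
# Toy stand-in for the forward/backward pair, using plain Ruby arrays.
# ToyCast and CONVERTERS are hypothetical; they are not part of Chainer or Numo.
class ToyCast
  # Maps a target class to a lambda that converts one element to that class.
  CONVERTERS = {
    Integer => ->(v) { v.to_i },
    Float   => ->(v) { v.to_f }
  }.freeze

  def initialize(type)
    @type = type
  end

  # Record the element class of the input, then convert every element to @type.
  def forward(xs)
    @in_type = xs.first.class
    xs.map(&CONVERTERS.fetch(@type))
  end

  # Convert the gradient back to the element class recorded in forward.
  def backward(gs)
    gs.map(&CONVERTERS.fetch(@in_type))
  end
end

f  = ToyCast.new(Float)
y  = f.forward([1, 2, 3])          # elements become Float
gx = f.backward([0.5, 1.5, 2.5])   # gradient elements return to Integer
```

As in the real class, backward only works after forward has run, because the input type is captured there. Converting a floating-point gradient back to an integer type truncates it in this toy version.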