class Chainer::Functions::Array::Rollaxis

Roll axis of an array.

Public Class Methods

new(axis, start)
# File lib/chainer/functions/array/rollaxis.rb, line 16
# Builds a Rollaxis function node.
#
# @param [Integer] axis Index of the axis to move; negative values are
#   counted from the end of the input's dimensions.
# @param [Integer] start Position the axis is rolled to; negative values are
#   counted from the end of the input's dimensions.
# @raise [ArgumentError] if +axis+ or +start+ is not an Integer.
def initialize(axis, start)
  raise ArgumentError, 'axis must be int' unless axis.is_a?(Integer)
  raise ArgumentError, 'start must be int' unless start.is_a?(Integer)

  @axis = axis
  @start = start
end
rollaxis(x, axis, start: 0)

Roll the axis backwards to the given position.

@param [Chainer::Variable] x Input variable.
@param [Integer] axis The axis to roll backwards.
@param [Integer] start The place to which the axis is moved.
@return [Chainer::Variable] Variable whose axis is rolled.

# File lib/chainer/functions/array/rollaxis.rb, line 12
# Rolls +axis+ of +x+ backwards until it lies at position +start+.
#
# @param [Chainer::Variable] x Input variable.
# @param [Integer] axis The axis to roll backwards.
# @param [Integer] start The place to which the axis is moved (defaults to 0).
# @return [Chainer::Variable] Variable whose axis is rolled.
def self.rollaxis(x, axis, start: 0)
  function_node = Rollaxis.new(axis, start)
  outputs = function_node.apply([x])
  outputs.first
end

Public Instance Methods

backward(indexes, gy)
# File lib/chainer/functions/array/rollaxis.rb, line 36
# Propagates the gradient by rolling +gy+ with the inverse of the forward roll.
#
# @param [Array] indexes Not used by this function.
# @param [Array] gy Output gradients; handed unchanged to the inverse Rollaxis.
# @return Result of applying the inverse Rollaxis node to +gy+.
def backward(indexes, gy)
  # Normalize negative indices against the input rank recorded in forward.
  axis = @axis
  axis += @in_ndim if axis < 0
  start = @start
  start += @in_ndim if start < 0

  # Assuming Chainer::Utils::Array.rollaxis mirrors numpy.rollaxis, forward
  # moved +axis+ to +start+ (or to +start - 1+ when start > axis); rolling
  # from that new position back to +axis+ undoes it.
  # When axis == start the forward roll was a no-op, so both stay as-is:
  # decrementing start there would yield -1 for start == 0, which is then
  # normalized to the last axis and wrongly permutes the gradient.
  if axis > start
    axis += 1
  elsif axis < start
    start -= 1
  end

  Rollaxis.new(start, axis).apply(gy)
end
forward(inputs)
# File lib/chainer/functions/array/rollaxis.rb, line 29
# Rolls the input array and records its rank for use in backward.
#
# @param [Array] inputs Array whose first element is the input array
#   (only the first element is used).
# @return [Array] One-element array holding the rolled array.
def forward(inputs)
  # backward only needs @axis, @start and @in_ndim, never the input values,
  # so no inputs are retained.
  retain_inputs([])
  x = inputs.first
  @in_ndim = x.ndim
  rolled = Chainer::Utils::Array.rollaxis(x, @axis, start: @start)
  [rolled]
end