class Chainer::Functions::Array::Squeeze

Public Class Methods

new(axis: nil) click to toggle source
# File lib/chainer/functions/array/squeeze.rb, line 15
def initialize(axis: nil)
  if axis.nil?
    @axis = nil
  elsif axis.kind_of?(Integer)
    @axis = [axis]
  elsif axis.kind_of?(::Array) && Array(axis).all? { |i| i.kind_of?(Integer) }
    @axis = axis
  else
    raise TypeError, 'axis must be None, int or tuple of ints'
  end
end
squeeze(x, axis: nil) click to toggle source

Remove demensions of size one from the shape of a Numo::NArray. @param [Chainer::Variable or Numo::NArray] x Input data. @param [nil or integer or array of integer] axis A subset of the single-dimensional entries in the shape to remove.

If `nil` is supplied, all of them are removed. The dimension index starts at zero.
If an axis with dimension greater than one is selected, an error is raised.

@return [Chainer::Variable] Variable whose dimensions of size 1 are removed.

# File lib/chainer/functions/array/squeeze.rb, line 11
def self.squeeze(x, axis: nil)
  self.new(axis: axis).apply([x]).first
end

Public Instance Methods

backward(indexes, grad_outputs) click to toggle source
# File lib/chainer/functions/array/squeeze.rb, line 47
def backward(indexes, grad_outputs)
  if @axis.nil?
    axis = argone(@inputs[0].shape)
  else
    axis = @axis
    ndim = @inputs[0].shape.size
    axis = axis.map { |x| x < 0 ? x + ndim : x }
    axis.sort!
  end
  gx = grad_outputs.first

  shape = gx.shape
  axis.each do |x|
    shape.insert(x, 1)
  end
  [gx.reshape(*shape)]
end
forward(inputs) click to toggle source
# File lib/chainer/functions/array/squeeze.rb, line 27
def forward(inputs)
  x = inputs.first
  shape = x.shape

  # TODO: numpy.squeeze
  if @axis.nil?
    new_shape = shape.reject { |axis| axis == 1 }
  else
    new_shape = shape
    @axis.map do |a|
      raise StandardError, "cannot select an axis to squeeze out which has size not equal to one" unless shape[a] == 1
      new_shape[a] = nil
    end
    new_shape.compact!
  end
  ret = new_shape.size.zero? ? x.class.new.fill(x[0]) : x.reshape(*new_shape)

  [ret]
end

Private Instance Methods

argone(iterable) click to toggle source
# File lib/chainer/functions/array/squeeze.rb, line 67
def argone(iterable)
  result = []
  Array(iterable).each_with_index do |x, i|
    raise StandardError, "elements in iterable must be int" unless x.kind_of?(Integer)
    result << i if x == 1
  end
  result
end