module Chainer::Utils::Initializer

Public Class Methods

get_fans(shape, device: Chainer::Device.default)
# File lib/chainer/utils/initializer.rb, line 4
# Computes the fan-in and fan-out of a weight array with the given shape.
#
# The receptive field size is the product of every dimension after the
# first two (1 when there are none). fan_in is shape[1] times that size,
# and fan_out is shape[0] times that size.
#
# @param shape [Array<Integer>] weight shape; must have at least 2 entries
#   (presumably [out, in, *kernel_dims] -- confirm against callers)
# @param device [Chainer::Device] retained for API compatibility; the
#   computation uses plain Ruby Integers and does not read it
# @return [Array(Integer, Integer)] [fan_in, fan_out]
# @raise [RuntimeError] if shape has fewer than 2 entries
def self.get_fans(shape, device: Chainer::Device.default)
  raise "shape must be of length >= 2: shape=#{shape.inspect}" if shape.size < 2

  # Plain Integer arithmetic is arbitrary-precision (an Int32 array product
  # can wrap around) and needs no round-trip through the device's array lib.
  receptive_field_size = shape.drop(2).reduce(1, :*)
  fan_in = shape[1] * receptive_field_size
  fan_out = shape[0] * receptive_field_size
  [fan_in, fan_out]
end