class OnnxChainer::Operators::Gemm

Public Class Methods

new(input_names:, output_shape:, output_names:, instance_variable_name:, need_initialized:) click to toggle source
# File lib/onnx-chainer/operators/gemm.rb, line 23
# Captures everything needed to emit the generated Chainer code for one Gemm node.
#
# @param input_names [Object] stored as-is in @input_names
# @param output_shape [Object] stored as-is in @output_shape (parse passes the bias dims)
# @param output_names [Hash] ONNX output name => generated local variable name
# @param instance_variable_name [String] ivar that will hold the link, e.g. "@l3"
# @param need_initialized [Object] stored as-is in @need_initialized
def initialize(input_names:, output_shape:, output_names:, instance_variable_name:, need_initialized:)
  @input_names, @output_shape, @output_names, @instance_variable_name, @need_initialized =
    input_names, output_shape, output_names, instance_variable_name, need_initialized
end
parse(node, input_names, inputs, output_name_index) click to toggle source
# File lib/onnx-chainer/operators/gemm.rb, line 7
# Builds a Gemm operator from an ONNX node.
#
# The generated Linear link is sized from the bias tensor, so the node must
# have an input whose name ends in "_b" that is also declared in +inputs+.
#
# @param node [#input, #output] ONNX node; #input/#output are lists of tensor names
# @param input_names [Object] passed through unchanged to the new operator
# @param inputs [Array<#name, #type>] declared graph inputs; each exposes
#   #type.tensor_type.shape.dim (a list of objects responding to #dim_value)
# @param output_name_index [Integer] suffix for the generated names ("l<index>")
# @return [Object] a new instance of the receiving class
# @raise [ArgumentError] if no "_b" input is found among +inputs+
def parse(node, input_names, inputs, output_name_index)
  bias_name = node.input.find { |name| name.end_with?('_b') }
  bias = inputs.find { |i| i.name == bias_name } if bias_name
  unless bias
    # Fail with a clear message instead of a NoMethodError on nil below.
    raise ArgumentError, "Gemm node has no bias input (expected a name ending in '_b')"
  end
  output_shape = bias.type.tensor_type.shape.dim.map(&:dim_value)

  # Build the declared-name list once rather than once per node input.
  # (Since the bias itself was found in +inputs+, this is true here.)
  known_names = inputs.map(&:name)
  need_initialized = node.input.any? { |name| known_names.include?(name) }

  output_names = { node.output.first => "l#{output_name_index}" }
  instance_variable_name = "@l#{output_name_index}"

  new(input_names: input_names, output_shape: output_shape, output_names: output_names,
      instance_variable_name: instance_variable_name, need_initialized: need_initialized)
end

Public Instance Methods

chainer_class() click to toggle source
# File lib/onnx-chainer/operators/gemm.rb, line 31
# The Chainer link class this operator is translated into. The leading `::`
# resolves the constant from the top-level namespace, so it cannot pick up a
# same-named constant nested under OnnxChainer.
def chainer_class
  ::Chainer::Links::Connection::Linear
end
to_call_string(args) click to toggle source
# File lib/onnx-chainer/operators/gemm.rb, line 39
# Renders the line of generated code that applies the link to its inputs,
# e.g. "l3 = @l3.(x, y)".
#
# @param args [Array<#to_s>] argument expressions, joined with ", "
# @return [String]
def to_call_string(args)
  target = @output_names.values.first
  argument_list = args.join(', ')
  "#{target} = #{@instance_variable_name}.(#{argument_list})"
end
to_initialize_string() click to toggle source
# File lib/onnx-chainer/operators/gemm.rb, line 35
# Renders the line of generated code that constructs the link, e.g.
# "@l3 = <chainer_class>.new(nil, out_size: [10])". The first argument is nil.
#
# NOTE(review): @output_shape comes from the bias dims (an Array), so it is
# interpolated as an Array literal such as "[10]" — confirm the target Linear
# accepts an Array for out_size rather than an Integer.
#
# @return [String]
def to_initialize_string
  constructor = "#{chainer_class}.new(nil, out_size: #{@output_shape})"
  "#{@instance_variable_name} = #{constructor}"
end