class OnnxChainer::Graph

Attributes

input_names[R]
nodes[R]
output_names[R]

Public Class Methods

new(model_name, nodes, input_names, output_names, target) click to toggle source
# File lib/onnx-chainer/graph.rb, line 78
# Stores the pieces of a parsed graph on the new instance.
#
# Each argument is kept as-is in the instance variable of the same name:
# @model_name, @nodes, @input_names, @output_names and @target.
def initialize(model_name, nodes, input_names, output_names, target)
  @model_name, @nodes = model_name, nodes
  @input_names, @output_names = input_names, output_names
  @target = target
end
parse(onnx_graph) click to toggle source
# File lib/onnx-chainer/graph.rb, line 10
# Builds a graph from an ONNX graph message.
#
# Reads +node+, +initializer+, +input+ and +name+ from +onnx_graph+ and
# passes the collected pieces to +new+:
#
# * input names:  graph inputs that are not initializers, mapped to the
#   generated argument names 'x', 'y', 'z', 'aa', ... (String#succ order).
# * nodes:        one operator object per ONNX node, built by the class
#   returned from +operator_klass+.
# * output names: the merged +output_names+ hashes of all operator objects.
# * target:       decoded initializer tensors keyed by "/@<a>/@<b>", where
#   <a> and <b> are the downcased 2nd and 3rd '_'-separated segments of the
#   initializer name.
#   NOTE(review): names with fewer than three segments raise NoMethodError
#   here -- confirm that every exporter emits "<prefix>_<link>_<param>".
#
# @param onnx_graph [Object] graph message responding to the readers above
# @return [Object] whatever +new+ returns for the collected pieces
def parse(onnx_graph)
  initializers = onnx_graph.initializer
  # Computed once and shared by the input filter and by every node below.
  initializer_names = initializers.map(&:name)

  # take out input
  call_inputs = onnx_graph.input.reject { |i| initializer_names.include?(i.name) }
  arg_name = 'x'
  input_names = call_inputs.each_with_object({}) do |input, names|
    names[input.name] = arg_name
    arg_name = arg_name.succ
  end

  # parse each node; the counter is 1-based and kept per op_type
  output_name_index = {}
  nodes = onnx_graph.node.map do |n|
    index = (output_name_index[n.op_type] ||= 1)
    klass = operator_klass(n.op_type)
    i_names = n.input.reject { |i| initializer_names.include?(i) }

    node = klass.parse(n, i_names, onnx_graph.input, index)

    output_name_index[n.op_type] = index + 1
    node
  end

  # take out output
  output_names = nodes.each_with_object({}) { |n, names| names.merge!(n.output_names) }

  # parameter: decode every initializer exactly once
  target = initializers.each_with_object({}) do |initializer, params|
    array = dtype(initializer.data_type).from_binary(initializer.raw_data).reshape(*initializer.dims)

    _prefix, link, param = initializer.name.split('_')
    params["/@#{link.downcase}/@#{param.downcase}"] = array
  end

  new(onnx_graph.name, nodes, input_names, output_names, target)
end

Private Class Methods

dtype(data_type) click to toggle source
# File lib/onnx-chainer/graph.rb, line 67
# Picks the Numo array class used to decode a tensor of the given data type.
#
# @param data_type [Object] value compared against the
#   Onnx::TensorProto::DataType constants FLOAT and INT8
# @return [Class] Numo::SFloat for FLOAT, Numo::Int8 for INT8
# @raise [TypeError] for any other value
def dtype(data_type)
  return Numo::SFloat if data_type == Onnx::TensorProto::DataType::FLOAT
  return Numo::Int8 if data_type == Onnx::TensorProto::DataType::INT8

  # Interpolate instead of using String#+: concatenating a non-String value
  # raises Ruby's own "no implicit conversion" TypeError, which would hide
  # the offending value from the message.
  raise TypeError, "unexpected value #{data_type}"
end
operator_klass(op_type) click to toggle source
# File lib/onnx-chainer/graph.rb, line 58
# Looks up the operator class that handles an ONNX op type.
#
# @param op_type [String] ONNX operator name; 'Gemm' and 'Relu' are handled
# @return [Class] OnnxChainer::Operators::Gemm or OnnxChainer::Operators::Relu
# @raise [ArgumentError] for any other op type, instead of returning nil and
#   letting the caller fail later with a NoMethodError on nil
def operator_klass(op_type)
  case op_type
  when 'Gemm' then OnnxChainer::Operators::Gemm
  when 'Relu' then OnnxChainer::Operators::Relu
  else
    raise ArgumentError, "unsupported operator: #{op_type}"
  end
end

Public Instance Methods

export(output_dir: nil, model_name: nil) click to toggle source

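Exports the graph to files: writes the generated Chainer model source to "model.rb" and a Marshal dump of the parameters to "resume", both inside output_dir (default: the current directory).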

# File lib/onnx-chainer/graph.rb, line 87
    # Writes the graph out as Chainer model source plus its parameters.
    #
    # Two files are created inside +output_dir+:
    # * "model.rb" -- a Chainer::Chain subclass whose +initialize+ holds the
    #   initialize strings of the nodes that need initialization and whose
    #   +call+ holds one call string per node.
    # * "resume"   -- a Marshal dump of @target.
    #
    # @param output_dir [String, nil] destination directory, '.' when nil;
    #   created together with any missing parent directories
    # @param model_name [String, nil] name for the generated class, taken
    #   from @model_name when nil; it is capitalized and every "_x" is turned
    #   into "X" (e.g. "my_model" becomes "MyModel")
    def export(output_dir: nil, model_name: nil)
      model_name ||= @model_name
      model_name = model_name.capitalize.gsub(/(?:^|_)(.)/) { Regexp.last_match(1).upcase }

      output_dir ||= '.'
      # mkdir_p also creates missing parents and is a no-op for an existing
      # directory; plain mkdir raised ENOENT for nested paths such as "a/b".
      FileUtils.mkdir_p(output_dir)

      # The join separators carry the indentation of the surrounding template.
      init_lines = @nodes.select(&:need_initialized).map(&:to_initialize_string).join("\n      ")
      call_lines = @nodes.map do |n|
        # An argument is either a graph input or the output of an earlier node.
        args = n.input_names.map { |name| @input_names[name] || @output_names[name] }
        n.to_call_string(args)
      end.join("\n    ")

source = <<EOS
require 'chainer'

class #{model_name} < Chainer::Chain
  def initialize
    super()
    init_scope do
      #{init_lines}
    end
  end

  def call(#{@input_names.values.join(', ')})
    #{call_lines}
  end
end
EOS

      File.open("#{output_dir}/model.rb", 'w') do |f|
        f.puts(source)
      end

      File.open("#{output_dir}/resume", 'wb+') do |f|
        Marshal.dump(@target, f)
      end
    end