class Chainer::Functions::Evaluation::Accuracy

Public Class Methods

accuracy(y, t, ignore_label: nil) click to toggle source
# File lib/chainer/functions/evaluation/accuracy.rb, line 5
def self.accuracy(y, t, ignore_label: nil)
  self.new(ignore_label: ignore_label).(y, t)
end
new(ignore_label: nil) click to toggle source
# File lib/chainer/functions/evaluation/accuracy.rb, line 9
def initialize(ignore_label: nil)
  @ignore_label = ignore_label
end

Public Instance Methods

forward(inputs) click to toggle source
# File lib/chainer/functions/evaluation/accuracy.rb, line 13
def forward(inputs)
  y, t = inputs
  xm = Chainer.get_array_module(*inputs)
  if @ignore_label
    mask = t.eq(@ignore_label)
    ignore_cnt = mask.count

    pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
    pred = pred.reshape(*t.shape)
    pred[mask] = @ignore_label
    count = pred.eq(t).count - ignore_cnt

    total = t.size - ignore_cnt

    if total == 0
      [y.class.cast(0.0)]
    else
      [y.class.cast(count.to_f / total)]
    end
  else
    pred = y.max_index(axis: 1) - xm::Int32.new(y.shape[0]).seq(0, y.shape[1])
    pred = pred.reshape(*t.shape)

    [y.class.cast(y.class[pred.eq(t)].mean)]
  end
end