class Chainer::Links::Model::Classifier

Attributes

compute_accuracy[RW]

Public Class Methods

new(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy), label_key=-1) click to toggle source

@param [Chainer::Link] predictor Predictor network. @param [Function] lossfun Loss function. @param [Function] accfun Function that computes accuracy. @param [Integer, String] label_key Key to specify label variable from arguments.

When it is Integer, a variable in positional arguments is used.
And when it is String, a variable in keyword arguments is used.
Calls superclass method Chainer::Chain::new
# File lib/chainer/links/model/classifier.rb, line 13
def initialize(predictor, lossfun=Functions::Loss::SoftmaxCrossEntropy.method(:softmax_cross_entropy), accfun=Functions::Evaluation::Accuracy.method(:accuracy), label_key=-1)
  super()

  unless label_key.is_a?(Integer) || label_key.is_a?(String)
    raise TypeError, "label_key must be Integer or String, but is #{label_key.class}"
  end

  @lossfun = lossfun
  @accfun = accfun
  @y = nil
  @loss = nil
  @accuracy = nil
  @compute_accuracy = true
  @label_key = label_key

  init_scope do
    @predictor = predictor
  end
end

Public Instance Methods

call(*args, **kwargs) click to toggle source
# File lib/chainer/links/model/classifier.rb, line 33
def call(*args, **kwargs)
  if @label_key.is_a?(Integer)
    raise IndexError, "label_key #{@label_key} is out of bounds" if @label_key < -args.size || @label_key >= args.size
    t = args.slice!(@label_key)
  elsif @label_key.is_a?(String)
    raise KeyError, "label_key #{@label_key} is not found" unless kwargs.has_key?(@label_key)
    t = kwargs[@label_key]
    kwargs.delete(@label_key)
  end

  @y = nil
  @accuracy = nil
  @y = @predictor.(*args, **kwargs)

  @loss = @lossfun.call(@y, t)
  Chainer::Reporter.save_report({loss: @loss}, self)
  if @compute_accuracy
    @accuracy = @accfun.call(@y, t)
    Chainer::Reporter.save_report({accuracy: @accuracy}, self)
  end
  @loss
end