class DNN::Losses::Loss

Public Class Methods

call(y, t, *args)
# File lib/dnn/core/losses.rb, line 5
# Builds a loss with the given constructor arguments and immediately applies
# it to a prediction/target pair.
#
# @param y the prediction passed to the new instance's #call.
# @param t the target passed to the new instance's #call.
# @param args forwarded verbatim to +new+.
# @return whatever the new instance's #call returns.
def self.call(y, t, *args)
  instance = new(*args)
  instance.call(y, t)
end
from_hash(hash)
# File lib/dnn/core/losses.rb, line 9
# Rebuilds a loss object from a hash such as the one produced by #to_hash.
#
# The class named by +hash[:class]+ is looked up under the DNN namespace and
# instantiated with +allocate+ (skipping +initialize+); #load_hash then sets
# the object up from the hash.
#
# NOTE(review): +hash[:class]+ is resolved with +const_get+, so a hash from an
# untrusted source can name any constant reachable from DNN. The type check
# below runs only after allocation.
#
# @param hash [Hash, nil] serialized form; nil/false yields nil.
# @return the restored loss, or nil when +hash+ is nil/false.
# @raise [DNNError] if the named class is not the receiver or a subclass of it.
def self.from_hash(hash)
  return unless hash

  klass = DNN.const_get(hash[:class])
  instance = klass.allocate
  unless instance.is_a?(self)
    raise DNNError, "#{instance.class} is not an instance of #{self} class."
  end

  instance.tap { |obj| obj.load_hash(hash) }
end

Public Instance Methods

call(y, t)
# File lib/dnn/core/losses.rb, line 18
# Evaluates the loss for a prediction/target pair by delegating to #forward.
#
# @param y the prediction.
# @param t the target.
# @return whatever #forward returns.
def call(y, t)
  __send__(:forward, y, t)
end
clean()
# File lib/dnn/core/losses.rb, line 55
# Resets every instance variable to nil, then restores the object's
# configuration from a #to_hash snapshot taken beforehand.
#
# @return whatever #load_hash returns.
def clean
  # Take the snapshot before wiping, since #to_hash may depend on the very
  # instance variables that are about to be cleared.
  snapshot = to_hash
  instance_variables.each { |name| instance_variable_set(name, nil) }
  load_hash(snapshot)
end
forward(y, t)
# File lib/dnn/core/losses.rb, line 32
# Abstract hook that computes the loss value; concrete loss classes must
# override it. #call delegates here.
#
# Note: NotImplementedError descends from ScriptError, not StandardError, so a
# bare `rescue` will not catch it.
#
# @param y the prediction.
# @param t the target.
# @raise [NotImplementedError] always, in this base implementation.
def forward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'forward'"
end
load_hash(hash)
# File lib/dnn/core/losses.rb, line 51
# Restores the object from a serialized hash. This base version ignores
# +hash+ and simply re-runs +initialize+ with no arguments. Losses whose
# constructors take parameters presumably override this (TODO confirm in the
# subclasses).
#
# @param hash [Hash] serialized form (unused here).
# @return whatever +initialize+ returns.
def load_hash(hash)
  __send__(:initialize)
end
loss(y, t, layers: nil, loss_weight: nil)
# File lib/dnn/core/losses.rb, line 22
# Computes the (optionally weighted and regularized) loss for y against t.
#
# @param y the prediction; must have the same +shape+ as +t+.
# @param t the target.
# @param layers [Array, nil] when given, passed to #regularizers_forward so
#   that layer regularizers can adjust the loss.
# @param loss_weight when truthy, the raw loss is multiplied by it.
# @return the resulting loss value.
# @raise [DNNShapeError] if +y.shape+ differs from +t.shape+.
def loss(y, t, layers: nil, loss_weight: nil)
  unless y.shape == t.shape
    raise DNNShapeError, "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}."
  end

  value = call(y, t)
  value = value * loss_weight if loss_weight
  layers ? regularizers_forward(value, layers) : value
end
regularizers_forward(loss, layers)
# File lib/dnn/core/losses.rb, line 36
# Passes the loss through the regularizers of every layer that exposes them.
#
# The regularizers of all such layers are collected and fully flattened,
# preserving layer order. Each one's #forward then receives the output of the
# previous one.
#
# @param loss the starting loss value.
# @param layers [Array] layers; those not responding to +regularizers+ are skipped.
# @return the loss after all regularizers were applied (unchanged if none).
def regularizers_forward(loss, layers)
  layers
    .select { |layer| layer.respond_to?(:regularizers) }
    .map(&:regularizers)
    .flatten
    .reduce(loss) { |acc, regularizer| regularizer.forward(acc) }
end
to_hash(merge_hash = nil)
# File lib/dnn/core/losses.rb, line 45
# Serializes the loss as a hash that records its class name under +:class+.
#
# @param merge_hash [Hash, nil] extra entries to add; its keys win on conflict
#   (including +:class+). It is not mutated.
# @return [Hash] a new hash.
def to_hash(merge_hash = nil)
  base = { class: self.class.name }
  return base unless merge_hash

  base.merge(merge_hash)
end