class DNN::Layers::GRUDense

Attributes

trainable [RW]

Public Class Methods

new(weight, recurrent_weight, bias)
# File lib/dnn/core/layers/rnn_layers.rb, line 371
# Builds a GRU cell around externally owned parameters.
#
# weight           - input-to-hidden parameter; #forward reads its #data and
#                    #backward accumulates into its #grad.
# recurrent_weight - hidden-to-hidden parameter with the same interface.
# bias             - bias parameter with the same interface, or nil/false to
#                    run the cell without a bias term.
#
# The update gate, reset gate and candidate state each get their own
# activation layer. NOTE(review): presumably separate instances are needed
# because each keeps the state its #backward_node uses after #forward_node.
def initialize(weight, recurrent_weight, bias)
  @weight, @recurrent_weight, @bias = weight, recurrent_weight, bias

  # Activations: update gate (z), reset gate (r), candidate state (h~).
  @update_sigmoid = Layers::Sigmoid.new
  @reset_sigmoid = Layers::Sigmoid.new
  @tanh = Layers::Tanh.new

  # Parameter gradients are accumulated in #backward only while this is true.
  @trainable = true
end

Public Instance Methods

backward(dh2)
# File lib/dnn/core/layers/rnn_layers.rb, line 404
# Backpropagates through one GRU time step.
#
# Must be called after #forward, whose cached state it reads: @x, @h,
# @update, @reset, @tanh_h and the weight slices @weight_a, @weight2_a,
# @weight_h, @weight2_h.
#
# dh2 - gradient of the loss with respect to the h2 returned by #forward.
#
# When @trainable is true, the per-slice weight gradients are concatenated
# back in the column layout #forward slices from: gate columns first, then
# candidate columns. They are added (+=) to @weight.grad and
# @recurrent_weight.grad, and to @bias.grad when a bias is present.
#
# Returns [dx, dh]: the gradients with respect to the step input x and the
# previous hidden state h.
def backward(dh2)
  # h2 = (1 - z) * h~ + z * h, so the candidate branch receives dh2 * (1 - z).
  # @tanh.backward_node turns that into the gradient of the tanh input,
  # which is then used as the candidate pre-activation gradient below.
  dtanh_h = @tanh.backward_node(dh2 * (1 - @update))
  # Direct path through the z * h term.
  dh = dh2 * @update

  if @trainable
    dweight_h = @x.transpose.dot(dtanh_h)
    dweight2_h = (@h * @reset).transpose.dot(dtanh_h)
    dbias_h = dtanh_h.sum(0) if @bias
  end
  dx = dtanh_h.dot(@weight_h.transpose)

  # Gradient with respect to (h * r). It feeds both the h path and the
  # reset-gate path, so the matrix product is computed once and reused.
  d_h_reset = dtanh_h.dot(@weight2_h.transpose)
  dh += d_h_reset * @reset

  dreset = @reset_sigmoid.backward_node(d_h_reset * @h)
  # dh2/dz = h - h~.
  dupdate = @update_sigmoid.backward_node(dh2 * @h - dh2 * @tanh_h)
  # Same [update | reset] column order as the gate pre-activation in #forward.
  da = Xumo::SFloat.hstack([dupdate, dreset])
  if @trainable
    dweight_a = @x.transpose.dot(da)
    dweight2_a = @h.transpose.dot(da)
    dbias_a = da.sum(0) if @bias
  end
  dx += da.dot(@weight_a.transpose)
  dh += da.dot(@weight2_a.transpose)

  if @trainable
    @weight.grad += Xumo::SFloat.hstack([dweight_a, dweight_h])
    @recurrent_weight.grad += Xumo::SFloat.hstack([dweight2_a, dweight2_h])
    @bias.grad += Xumo::SFloat.hstack([dbias_a, dbias_h]) if @bias
  end
  [dx, dh]
end
forward(x, h)
# File lib/dnn/core/layers/rnn_layers.rb, line 381
# Runs one GRU time step and caches the intermediates #backward needs.
#
# x - input for this step.
# h - previous hidden state. The unit count is read from h.shape[1]
#     (assumes 2-D (batch, units) arrays; TODO confirm).
#
# Column layout of @weight / @recurrent_weight (and of @bias, if present):
#   [0, 2 * units)   gate pre-activation: update gate, then reset gate
#   [2 * units, ..)  candidate-state pre-activation
#
# Returns h2 = (1 - z) * h~ + z * h, where z is the update gate and h~ is
# the tanh candidate state.
def forward(x, h)
  @x = x
  @h = h
  units = h.shape[1]
  gate_cols = 0...(units * 2)
  cand_cols = (units * 2)..-1

  # Gates: z from the first `units` columns, r from the rest.
  @weight_a = @weight.data[true, gate_cols]
  @weight2_a = @recurrent_weight.data[true, gate_cols]
  gate_in = x.dot(@weight_a) + h.dot(@weight2_a)
  gate_in += @bias.data[gate_cols] if @bias
  @update = @update_sigmoid.forward_node(gate_in[true, 0...units])
  @reset = @reset_sigmoid.forward_node(gate_in[true, units..-1])

  # Candidate state: the reset gate scales h before the recurrent projection.
  @weight_h = @weight.data[true, cand_cols]
  @weight2_h = @recurrent_weight.data[true, cand_cols]
  cand_in = x.dot(@weight_h) + (h * @reset).dot(@weight2_h)
  cand_in += @bias.data[cand_cols] if @bias
  @tanh_h = @tanh.forward_node(cand_in)

  (1 - @update) * @tanh_h + @update * h
end