class DNN::Layers::GRUDense
Attributes
trainable[RW]
Public Class Methods
new(weight, recurrent_weight, bias)
click to toggle source
# File lib/dnn/core/layers/rnn_layers.rb, line 371
# Stores the step parameters and builds the activation layers used by
# #forward and #backward.
#
# @param weight [Object] input weight; #forward reads it through +data+ and
#   #backward accumulates into its +grad+
# @param recurrent_weight [Object] hidden-state weight, used the same way
# @param bias [Object, nil] bias parameter; #forward and #backward skip every
#   bias term when this is nil or false
def initialize(weight, recurrent_weight, bias)
  @weight, @recurrent_weight, @bias = weight, recurrent_weight, bias
  # Two independent sigmoid layers (update gate first, then reset gate).
  @update_sigmoid, @reset_sigmoid = Array.new(2) { Layers::Sigmoid.new }
  @tanh = Layers::Tanh.new
  @trainable = true
end
Public Instance Methods
backward(dh2)
click to toggle source
# File lib/dnn/core/layers/rnn_layers.rb, line 404
# Backpropagates through one step using the values cached by #forward
# (@x, @h, @update, @reset, @tanh_h and the weight slices).
#
# When @trainable is truthy the parameter gradients are added to
# @weight.grad, @recurrent_weight.grad and, if a bias is present, @bias.grad.
# Each accumulated gradient is laid out as [gate part, candidate part],
# matching the column split used by #forward.
#
# @param dh2 gradient with respect to the hidden state returned by #forward
# @return [Array] two-element array [dx, dh]: gradients with respect to the
#   input x and the previous hidden state h
def backward(dh2)
  # forward: h2 = (1 - update) * tanh_h + update * h
  dtanh_h = @tanh.backward_node(dh2 * (1 - @update))
  dh = dh2 * @update

  dx = dtanh_h.dot(@weight_h.transpose)
  # Gradient w.r.t. (h * reset). Both dh and dreset need it, so the matrix
  # product is evaluated once instead of twice.
  dgated_h = dtanh_h.dot(@weight2_h.transpose)
  dh += dgated_h * @reset

  dreset = @reset_sigmoid.backward_node(dgated_h * @h)
  dupdate = @update_sigmoid.backward_node(dh2 * @h - dh2 * @tanh_h)
  # Same column order as the gate slice in forward: update first, then reset.
  da = Xumo::SFloat.hstack([dupdate, dreset])

  dx += da.dot(@weight_a.transpose)
  dh += da.dot(@weight2_a.transpose)

  if @trainable
    x_t = @x.transpose
    @weight.grad += Xumo::SFloat.hstack([x_t.dot(da), x_t.dot(dtanh_h)])
    @recurrent_weight.grad += Xumo::SFloat.hstack(
      [@h.transpose.dot(da), (@h * @reset).transpose.dot(dtanh_h)]
    )
    @bias.grad += Xumo::SFloat.hstack([da.sum(0), dtanh_h.sum(0)]) if @bias
  end

  [dx, dh]
end
forward(x, h)
click to toggle source
# File lib/dnn/core/layers/rnn_layers.rb, line 381
# Runs one step and caches the intermediates that #backward reads
# (@x, @h, @update, @reset, @tanh_h and the weight slices).
#
# The first 2 * num_units columns of the weight and recurrent weight (and the
# matching bias entries) feed the update and reset sigmoids; the remaining
# columns feed the tanh candidate.
#
# @param x input for this step (assumed 2-D, rows aligned with h — TODO confirm)
# @param h previous hidden state; h.shape[1] is taken as num_units
# @return the new hidden state, (1 - update) * tanh_h + update * h
def forward(x, h)
  @x = x
  @h = h
  units = h.shape[1]
  gate_cols = 0...(units * 2)
  cand_cols = (units * 2)..-1

  # Update / reset gates share one pre-activation, split by column afterwards.
  @weight_a = @weight.data[true, gate_cols]
  @weight2_a = @recurrent_weight.data[true, gate_cols]
  gates = x.dot(@weight_a) + h.dot(@weight2_a)
  gates += @bias.data[gate_cols] if @bias
  @update = @update_sigmoid.forward_node(gates[true, 0...units])
  @reset = @reset_sigmoid.forward_node(gates[true, units..-1])

  # Candidate state: the previous hidden state is scaled by the reset gate
  # before it meets the recurrent weight.
  @weight_h = @weight.data[true, cand_cols]
  @weight2_h = @recurrent_weight.data[true, cand_cols]
  candidate = x.dot(@weight_h) + (h * @reset).dot(@weight2_h)
  candidate += @bias.data[cand_cols] if @bias
  @tanh_h = @tanh.forward_node(candidate)

  (1 - @update) * @tanh_h + @update * h
end