class DNN::Layers::RNN

Superclass of all recurrent (RNN) layer classes.

Attributes

hidden[R]
num_units[R]
recurrent_weight[R]
recurrent_weight_initializer[R]
recurrent_weight_regularizer[R]
return_sequences[R]
stateful[R]

Public Class Methods

new(num_units, stateful: false, return_sequences: true, weight_initializer: Initializers::RandomNormal.new, recurrent_weight_initializer: Initializers::RandomNormal.new, bias_initializer: Initializers::Zeros.new, weight_regularizer: nil, recurrent_weight_regularizer: nil, bias_regularizer: nil, use_bias: true)

@param [Integer] num_units Number of units (nodes).
@param [Boolean] stateful Whether to keep the hidden state between batches.
@param [Boolean] return_sequences If false, only the output of the last time step is returned instead of the whole sequence.
@param [DNN::Initializers::Initializer] recurrent_weight_initializer Initializer for the recurrent weight.
@param [DNN::Regularizers::Regularizer | NilClass] recurrent_weight_regularizer Regularizer for the recurrent weight, or nil for none.

Calls superclass method DNN::Layers::Connection::new
# File lib/dnn/core/layers/rnn_layers.rb, line 21
# Builds an RNN layer.
# Weight/bias settings are forwarded to the superclass; recurrent-specific
# settings and state containers are stored on this instance.
# @param [Integer] num_units Number of units.
# @param [Boolean] stateful Reuse the last hidden state as the next initial state.
# @param [Boolean] return_sequences Return every time step (true) or only the last (false).
def initialize(num_units,
               stateful: false,
               return_sequences: true,
               weight_initializer: Initializers::RandomNormal.new,
               recurrent_weight_initializer: Initializers::RandomNormal.new,
               bias_initializer: Initializers::Zeros.new,
               weight_regularizer: nil,
               recurrent_weight_regularizer: nil,
               bias_regularizer: nil,
               use_bias: true)
  super(weight_initializer: weight_initializer,
        bias_initializer: bias_initializer,
        weight_regularizer: weight_regularizer,
        bias_regularizer: bias_regularizer,
        use_bias: use_bias)
  # Layer configuration.
  @num_units = num_units
  @stateful = stateful
  @return_sequences = return_sequences
  # One hidden-layer object per time step; filled by create_hidden_layer.
  @hidden_layers = []
  # Holds the last hidden state produced by forward_node.
  @hidden = Param.new
  # Placeholder data until init_weight_and_bias runs the initializer.
  @recurrent_weight = Param.new(nil, Xumo::SFloat[0])
  @recurrent_weight_initializer = recurrent_weight_initializer
  @recurrent_weight_regularizer = recurrent_weight_regularizer
end

Public Instance Methods

backward_node(dh2s)
# File lib/dnn/core/layers/rnn_layers.rb, line 69
# Back-propagates through the per-step hidden layers, from the last time step
# to the first. At each step t, the gradient carried back from step t+1 is
# added to the output gradient for step t.
# @param dh2s Output gradient. It is (batch?, time, units) when return_sequences
#   is true, or only the last step's gradient when it is false.
#   (Batch as the leading axis is presumed; confirm.)
# @return Gradient w.r.t. the forward input, with shape @xs_shape.
def backward_node(dh2s)
  unless @return_sequences
    # Only the final step was emitted. Embed its gradient in a zero sequence.
    last_step_grad = dh2s
    dh2s = Xumo::SFloat.zeros(last_step_grad.shape[0], @time_length, last_step_grad.shape[1])
    dh2s[true, -1, false] = last_step_grad
  end
  dxs = Xumo::SFloat.zeros(@xs_shape)
  carried = 0
  (0...dh2s.shape[1]).reverse_each do |step|
    dx, carried = @hidden_layers[step].backward(dh2s[true, step, false] + carried)
    dxs[true, step, false] = dx
  end
  dxs
end
build(input_shape)
Calls superclass method DNN::Layers::Layer#build
# File lib/dnn/core/layers/rnn_layers.rb, line 43
# Validates that the per-sample input is 2-dimensional (time, features),
# then delegates to the superclass build.
# @raise [DNNShapeError] if input_shape does not have exactly two entries.
def build(input_shape)
  rank = input_shape.length
  raise DNNShapeError, "Input shape is #{input_shape}. But input shape must be 2 dimensional." if rank != 2
  super
end
compute_output_shape()
# File lib/dnn/core/layers/rnn_layers.rb, line 85
# Records the sequence length from the input shape in @time_length and
# returns the per-sample output shape.
# @return [Array<Integer>] [time, num_units] with return_sequences,
#   otherwise [num_units].
def compute_output_shape
  @time_length = @input_shape.first
  return [@num_units] unless @return_sequences

  [@time_length, @num_units]
end
forward_node(xs)
# File lib/dnn/core/layers/rnn_layers.rb, line 54
# Runs the recurrence over the time axis (axis 1) of xs, one hidden-layer
# object per step. The initial state is the stored hidden state when the
# layer is stateful and one exists; otherwise it is zeros.
# @return Every step's hidden state when return_sequences is true,
#   otherwise only the final hidden state.
def forward_node(xs)
  create_hidden_layer
  @xs_shape = xs.shape
  batch_size, steps = xs.shape
  outputs = Xumo::SFloat.zeros(batch_size, @time_length, @num_units)
  state = if @stateful && @hidden.data
            @hidden.data
          else
            Xumo::SFloat.zeros(batch_size, @num_units)
          end
  steps.times do |step|
    cell = @hidden_layers[step]
    cell.trainable = @trainable
    state = cell.forward(xs[true, step, false], state)
    outputs[true, step, false] = state
  end
  # Remember the final state (reused on the next call when stateful).
  @hidden.data = state
  @return_sequences ? outputs : state
end
get_params()
# File lib/dnn/core/layers/rnn_layers.rb, line 115
# Returns this layer's parameters keyed by name.
# The hidden state is included alongside the weights.
def get_params
  {
    weight: @weight,
    recurrent_weight: @recurrent_weight,
    bias: @bias,
    hidden: @hidden,
  }
end
load_hash(hash)
# File lib/dnn/core/layers/rnn_layers.rb, line 102
# Re-initializes this layer from a hash in the format produced by to_hash.
# Nested initializer/regularizer entries are rebuilt through their
# from_hash factories.
def load_hash(hash)
  init = Initializers::Initializer
  reg = Regularizers::Regularizer
  initialize(hash[:num_units],
             stateful: hash[:stateful],
             return_sequences: hash[:return_sequences],
             weight_initializer: init.from_hash(hash[:weight_initializer]),
             recurrent_weight_initializer: init.from_hash(hash[:recurrent_weight_initializer]),
             bias_initializer: init.from_hash(hash[:bias_initializer]),
             weight_regularizer: reg.from_hash(hash[:weight_regularizer]),
             recurrent_weight_regularizer: reg.from_hash(hash[:recurrent_weight_regularizer]),
             bias_regularizer: reg.from_hash(hash[:bias_regularizer]),
             use_bias: hash[:use_bias])
end
regularizers()
# File lib/dnn/core/layers/rnn_layers.rb, line 124
# Returns the configured regularizers in the order weight, recurrent weight,
# bias, skipping any that are unset (nil).
def regularizers
  [@weight_regularizer, @recurrent_weight_regularizer, @bias_regularizer].select(&:itself)
end
reset_state()

Resets the RNN's hidden state by filling the stored state with zeros, if one exists.

# File lib/dnn/core/layers/rnn_layers.rb, line 120
# Zero-fills the stored hidden state in place; does nothing if no state has
# been recorded yet.
def reset_state
  state = @hidden.data
  return unless state

  @hidden.data = state.fill(0)
end
to_hash(merge_hash = nil)
Calls superclass method DNN::Layers::Connection#to_hash
# File lib/dnn/core/layers/rnn_layers.rb, line 90
# Serializes the RNN-specific settings, merges in any subclass-provided
# entries (which win on key clashes), and hands the result to the superclass.
# @param [Hash, nil] merge_hash Extra entries from subclasses.
def to_hash(merge_hash = nil)
  rnn_settings = {
    num_units: @num_units,
    stateful: @stateful,
    return_sequences: @return_sequences,
    recurrent_weight_initializer: @recurrent_weight_initializer.to_hash,
    recurrent_weight_regularizer: @recurrent_weight_regularizer&.to_hash,
  }
  rnn_settings = rnn_settings.merge(merge_hash) if merge_hash
  super(rnn_settings)
end

Private Instance Methods

create_hidden_layer()
# File lib/dnn/core/layers/rnn_layers.rb, line 50
        # Abstract hook. Subclasses must populate @hidden_layers with one
        # hidden-layer object per time step; forward_node calls this first.
        # @raise [NotImplementedError] always, when not overridden.
        def create_hidden_layer
          raise NotImplementedError, "Class '#{self.class.name}' has not implemented method 'create_hidden_layer'"
        end
init_weight_and_bias()
# File lib/dnn/core/layers/rnn_layers.rb, line 132
        # Initializes weight and bias through the superclass. It then initializes
        # the recurrent weight and, if a recurrent regularizer is set, points it
        # at that weight.
        def init_weight_and_bias
          super
          @recurrent_weight_initializer.init_param(self, @recurrent_weight)
          return unless @recurrent_weight_regularizer

          @recurrent_weight_regularizer.param = @recurrent_weight
        end