class DNN::Models::Chain
Public Class Methods
new()
click to toggle source
# File lib/dnn/core/models.rb, line 42
# Create an empty chain.
# The memo used by #layers starts out unset (nil), so the first #layers call
# walks the instance variables and fills it in.
def initialize
  @layers_cache = nil
end
Public Instance Methods
call(input_tensors)
click to toggle source
Forward propagation that also creates a link; the work is delegated to #forward and its result is returned. @param [Tensor | Array] input_tensors Input tensors. @return [Tensor] Output tensor.
# File lib/dnn/core/models.rb, line 56
# Run forward propagation on +input_tensors+.
# This is a thin entry point: all of the work is done by #forward, whose
# result is returned unchanged.
# @param [Tensor | Array] input_tensors Input tensors.
# @return [Tensor] Output tensor.
def call(input_tensors)
  forward(input_tensors)
end
forward(input_tensors)
click to toggle source
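Forward propagation. Subclasses must override this method; the default implementation raises NotImplementedError. @param [Tensor | Array] input_tensors Input tensors. @return [Tensor] Output tensor.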
# File lib/dnn/core/models.rb, line 49
# Forward propagation. This base implementation is abstract: every concrete
# chain must override it.
# @param [Tensor | Array] input_tensors Input tensors.
# @return [Tensor] Output tensor.
# @raise [NotImplementedError] Always, unless overridden by a subclass.
def forward(input_tensors)
  raise NotImplementedError, "Class '#{self.class.name}' has not implemented method 'forward'"
end
layers()
click to toggle source
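Get all layers: those held directly in instance variables plus the layers of nested Chain and LayersList members, in sorted instance-variable order. The result is cached after the first call. @return [Array] Array of all layers.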
# File lib/dnn/core/models.rb, line 62
# Collect every layer reachable from this chain.
# Instance variables are visited in sorted name order:
# - a Layers::Layer member is appended directly;
# - a Chain or LayersList member contributes its own #layers, flattened in.
# Any other instance variable is ignored.
# The result is memoized in @layers_cache, and an empty array counts as a
# cached result.
# @return [Array] All layers array.
def layers
  @layers_cache ||= instance_variables.sort.each_with_object([]) do |ivar, collected|
    member = instance_variable_get(ivar)
    case member
    when Layers::Layer then collected << member
    when Chain, LayersList then collected.concat(member.layers)
    end
  end
end
load_hash(layers_hash)
click to toggle source
# File lib/dnn/core/models.rb, line 89
# Restore this chain's members from a hash shaped like the output of #to_hash.
# Only instance variables that already exist on the receiver are restored;
# entries in +layers_hash+ with no matching instance variable are ignored.
# For each instance variable, the value stored under that variable's symbol
# (e.g. :@dense) is handled as follows:
# - Array: rebuilt with LayersList.from_hash_list.
# - Hash whose :class names a Chain subclass: a fresh instance of that class
#   is created with .new and filled recursively via #load_hash.
# - Any other Hash: rebuilt with Layers::Layer.from_hash.
# - Anything else (including nil): the variable is left untouched.
# NOTE(review): :class names are resolved with DNN.const_get, so only load
#   hashes that come from a trusted source.
# @param [Hash] layers_hash Serialized members keyed by instance-variable symbol.
def load_hash(layers_hash)
  # Members are about to be replaced, so any memoized #layers result would
  # point at stale objects; clear it so the next #layers call recomputes.
  @layers_cache = nil
  instance_variables.sort.each do |ivar|
    hash_or_array = layers_hash[ivar]
    case hash_or_array
    when Array
      instance_variable_set(ivar, LayersList.from_hash_list(hash_or_array))
    when Hash
      obj_class = DNN.const_get(hash_or_array[:class])
      # Class ancestry check; no need to allocate a throwaway instance.
      if obj_class <= Chain
        chain = obj_class.new
        chain.load_hash(hash_or_array)
        instance_variable_set(ivar, chain)
      else
        instance_variable_set(ivar, Layers::Layer.from_hash(hash_or_array))
      end
    end
  end
end
to_hash()
click to toggle source
# File lib/dnn/core/models.rb, line 76
# Serialize this chain to a Hash.
# The result always carries :class (this object's class name). It also gets
# one entry per instance variable holding:
# - a Layers::Layer or Chain, stored via its #to_hash;
# - a LayersList, stored via its #to_hash_list.
# Other instance variables (including the @layers_cache memo) are skipped.
# Entries are added in sorted instance-variable order.
# @return [Hash] Serialized form, suitable for #load_hash.
def to_hash
  instance_variables.sort.each_with_object({ class: self.class.name }) do |ivar, serialized|
    member = instance_variable_get(ivar)
    case member
    when Layers::Layer, Chain then serialized[ivar] = member.to_hash
    when LayersList then serialized[ivar] = member.to_hash_list
    end
  end
end