class DNN::Layers::GlobalAvgPool2D

Public Instance Methods

build(input_shape)
Calls superclass method DNN::Layers::Layer#build
# File lib/dnn/core/layers/cnn_layers.rb, line 423
# Validates the incoming shape and then runs the parent layer's build.
#
# Only shapes with exactly three entries are accepted. The meaning of the
# three axes is not visible here; presumably (height, width, channels) —
# TODO confirm against AvgPool2D.
#
# @param input_shape [Array] shape of one input sample (length must be 3)
# @raise [DNNShapeError] if input_shape does not have exactly 3 entries
def build(input_shape)
  dims = input_shape.length
  raise DNNShapeError, "Input shape is #{input_shape}. But input shape must be 3 dimensional." if dims != 3

  # Zero-arg super forwards input_shape unchanged to the superclass build.
  super
end
forward(x)
# File lib/dnn/core/layers/cnn_layers.rb, line 430
# Global average pooling: averages over the full spatial extent, then flattens.
#
# The pool size is taken from the first two entries of @input_shape. That
# value is presumably recorded by the superclass build — TODO confirm.
#
# @param x input passed straight to AvgPool2D
# @return the flattened result of the pooling step
def forward(x)
  window = @input_shape[0, 2]
  pooled = AvgPool2D.call(x, window)
  Flatten.call(pooled)
end